You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/inference/kv_cache/kvcache_manager.py

603 lines
28 KiB

from typing import List, Tuple
import torch
from transformers.configuration_utils import PretrainedConfig
from colossalai.inference.config import InferenceConfig
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from .block_cache import CacheBlock
# Names exported by `from <this module> import *`.
__all__ = ["KVCacheManager"]
# Number of bytes in 1024**3 bytes; used as the divisor when logging the KV cache size in GB.
GIGABYTE = 1024**3
class KVCacheManager:
    """KVCacheManager manages both the logical cache blocks and physical KV cache (tensors).

    NOTE: The KVCacheManager is designed to be interacted with indices of logical blocks.
        That is, it won't allocate and return a physical cache to the engine or scheduler;
        instead, it will mark the logical block as allocated and update the block id representing
        the physical cache to the caller. The physical cache is actually used and updated in kernels.

    Example
        A block table of a single sequence before block allocation might be:
        | -1 | -1 | -1 | -1 | -1 | -1 |
        where the maximum blocks per sequence is 6
        The block table after block allocation might be:
        |  0 |  1 |  2 | -1 | -1 | -1 |
        Then the logical blocks with id 0, 1, and 2, are allocated for this sequence,
        and the physical caches, each with size of `block_size * kv_head_num * head_size * elem_size` for a single layer,
        corresponding to these blocks will be used to read/write KV Caches in kernels.

        For a batch of sequences, the block tables after allocation might be:
        |  0 |  1 |  2 | -1 | -1 | -1 |
        |  3 |  4 |  5 |  6 |  7 | -1 |
        |  8 |  9 | 10 | 11 | -1 | -1 |
        | 12 | 13 | 14 | 15 | -1 | -1 |
        where 16 logical cache blocks are allocated and the same number of physical cache blocks will be used in kernels.

        Currently, allocations and updates are done at granularity of a single sequence.
        That is, the block table should be a 1D tensor of shape [max_blocks_per_sequence].
        And it's possible to have a batch of sequences with different lengths of block tables.
    """

    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig) -> None:
        """Allocate the physical KV caches and the logical blocks that describe them.

        Args:
            config: Inference settings (parallelism, dtypes, batch/length limits, block size).
            model_config: Model settings (layer count, head counts, hidden size).
        """
        self.logger = get_dist_logger(__name__)
        self.device = get_current_device()

        # Parallel settings
        self.tp_size = config.tp_size

        # Model settings
        self.dtype = config.dtype
        if config.kv_cache_dtype is None:
            self.kv_cache_dtype = config.dtype
        else:
            self.kv_cache_dtype = config.kv_cache_dtype
        # The cache tensors are created with `kv_cache_dtype`, so byte sizes and
        # per-block pointer strides must be derived from that dtype.
        self.elem_size_in_bytes = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
        self.num_layers = model_config.num_hidden_layers
        self.head_num = model_config.num_attention_heads
        self.head_size = model_config.hidden_size // self.head_num
        if hasattr(model_config, "num_key_value_heads"):
            self.kv_head_num = model_config.num_key_value_heads
        else:
            self.kv_head_num = self.head_num

        assert (
            self.kv_head_num % self.tp_size == 0
        ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
        self.kv_head_num //= self.tp_size
        self.beam_width = config.beam_width
        self.max_batch_size = config.max_batch_size
        self.max_input_length = config.max_input_len
        self.max_output_length = config.max_output_len

        # Cache block settings
        self.block_size = config.block_size
        # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
        if config.enable_streamingllm:
            self.max_blocks_per_sequence = (
                config.start_token_size + config.generated_token_size + self.block_size - 1
            ) // self.block_size + 1
        else:
            self.max_blocks_per_sequence = (
                self.max_input_length + self.max_output_length + self.block_size - 1
            ) // self.block_size
        self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

        # Physical cache allocation
        if config.use_cuda_kernel:
            x = 16 // torch.tensor([], dtype=config.dtype).element_size()
            kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
            valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
            self.logger.info(
                f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
            )
            self._kv_caches = self._init_device_caches(kalloc_shape, valloc_shape)
        else:
            alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
            self._kv_caches = self._init_device_caches(alloc_shape, alloc_shape)
        self.total_physical_cache_size_in_bytes = (
            self.elem_size_in_bytes
            * self.num_layers
            * 2
            * self.num_blocks
            * self.block_size
            * self.kv_head_num
            * self.head_size
        )
        self.logger.info(
            f"Allocated {self.total_physical_cache_size_in_bytes / GIGABYTE:.2f} GB of KV cache on device {self.device}."
        )

        # Logical cache blocks allocation
        self._available_blocks = self.num_blocks
        self._cache_blocks = tuple(self._init_logical_caches())
        # block availability state 0->allocated, 1->free
        self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool)
        # Scratch buffers reused by the contiguous free-range search.
        self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64)
        self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64)

    @property
    def total_num_blocks(self) -> int:
        """Get the total number of logical cache blocks."""
        return self.num_blocks

    @property
    def num_available_blocks(self) -> int:
        """Get the number of available cache blocks."""
        return self._available_blocks

    def get_head_size(self):
        """Get the size of a single attention head."""
        return self.head_size

    def get_kv_cache(self):
        """Get k_cache and v_cache"""
        return self._kv_caches

    def get_max_blocks_per_sequence(self) -> int:
        """Get the maximum number of blocks that can be allocated for a single sequence."""
        # TODO Consider removing this function as we plan to implement "half-dynamic" batching in schduler/request handler,
        #      which will make the max_blocks_per_sequence dynamic based on the prompt lengths of sequences
        #      in the current batch.
        return self.max_blocks_per_sequence

    def check_allocation(self, seq: Sequence) -> bool:
        """Check whether enough free blocks exist for `seq.input_len` plus the maximum output length."""
        num_blocks_needed = (seq.input_len + self.max_output_length + self.block_size - 1) // self.block_size
        return num_blocks_needed <= self.num_available_blocks

    def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[int, int]:
        """Get the key and value pointers of physical caches (of specific layer) corresponding to a logical cache block."""
        block: CacheBlock = self._cache_blocks[block_id]
        return block.k_ptrs[layer_id], block.v_ptrs[layer_id]

    def get_block_table_kv_ptrs(self, block_table: torch.Tensor, layer_id: int) -> Tuple[List[int], List[int]]:
        """Get the key and value pointers of physical caches (of specific layer) corresponding to logical cache blocks indicated by the block table."""
        k_ptrs = []
        v_ptrs = []
        for block_id in block_table.tolist():
            # Negative entries mark slots that have no block assigned.
            if block_id >= 0:
                block: CacheBlock = self._cache_blocks[block_id]
                k_ptrs.append(block.k_ptrs[layer_id])
                v_ptrs.append(block.v_ptrs[layer_id])
        return k_ptrs, v_ptrs

    def allocate_context_from_block_table(self, block_table: torch.Tensor, context_len: int) -> None:
        """Allocate the logical cache blocks for a single sequence during prefill stage,
        and updates the provided block table with the allocated block ids.

        If there are not enough free blocks, a warning is logged and nothing is allocated.

        Args:
            block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
            context_len: The length of the processing sequence.
        """
        assert block_table.dim() == 1
        if not torch.all(block_table < 0):
            self.logger.error("Some slots on provided block table have been allocated.")
        blocks_required = (context_len + self.block_size - 1) // self.block_size
        if blocks_required > self._available_blocks:
            self.logger.warning(
                f"No enough blocks to allocate. Available blocks {self._available_blocks}; context length {context_len}."
            )
            return
        if blocks_required <= 0:
            # Nothing to allocate for an empty context.
            return

        block_indexes = self._select_free_block_ids(blocks_required)

        # Update block table
        block_table[:blocks_required] = block_indexes.to(device=block_table.device)
        # Update cache blocks
        self._block_states[block_indexes] = 0
        self._available_blocks -= blocks_required
        block_id_list = block_indexes.tolist()
        last_block_id = block_id_list[-1]
        for block_id in block_id_list:
            block: CacheBlock = self._cache_blocks[block_id]
            block.add_ref()
            if block_id == last_block_id:
                # The last block may be only partially filled.
                tail_len = context_len % block.block_size
                self._allocate_on_block(block, tail_len if tail_len != 0 else block.block_size)
            else:
                self._allocate_on_block(block, block.block_size)

    def allocate_context_from_block_tables(self, block_tables: torch.Tensor, context_lengths: torch.Tensor) -> None:
        """Allocate logical cache blocks for a batch of sequences during prefill stage.

        If there are not enough free blocks for the whole batch, a warning is logged
        and nothing is allocated.

        Args:
            block_tables (torch.Tensor): [bsz, max_blocks_per_sequence]
            context_lengths (torch.Tensor): [bsz]
        """
        assert block_tables.dim() == 2
        assert block_tables.size(0) == context_lengths.size(0)
        if not torch.all(block_tables < 0):
            self.logger.error("Some slots on provided block table have been allocated.")
        blocks_required = (context_lengths + self.block_size - 1) // self.block_size
        num_blocks_required = torch.sum(blocks_required).item()
        assert isinstance(num_blocks_required, int)
        if num_blocks_required > self._available_blocks:
            self.logger.warning(
                f"Lacking blocks to allocate. Available blocks {self._available_blocks}; blocks asked {num_blocks_required}."
            )
            return
        if num_blocks_required <= 0:
            # Nothing to allocate (empty batch or all contexts empty).
            return

        alloc_block_ids = self._select_free_block_ids(num_blocks_required)
        table_block_ids = alloc_block_ids.to(dtype=block_tables.dtype, device=block_tables.device)

        # Update cache blocks
        self._block_states[alloc_block_ids] = 0
        self._available_blocks -= num_blocks_required

        block_id_list = alloc_block_ids.tolist()
        per_seq_required = blocks_required.tolist()
        per_seq_context_len = context_lengths.tolist()
        offset = 0
        for seq_idx, (curr_required, context_len) in enumerate(zip(per_seq_required, per_seq_context_len)):
            if curr_required == 0:
                continue
            next_offset = offset + curr_required
            # Update block table
            block_tables[seq_idx, :curr_required] = table_block_ids[offset:next_offset]
            seq_block_ids = block_id_list[offset:next_offset]
            offset = next_offset

            # All blocks but the last one are fully occupied.
            for block_id in seq_block_ids[:-1]:
                block: CacheBlock = self._cache_blocks[block_id]
                block.add_ref()
                self._allocate_on_block(block, block.block_size)
            # The last block of the sequence may be only partially filled.
            last_block: CacheBlock = self._cache_blocks[seq_block_ids[-1]]
            last_block.add_ref()
            tail_len = context_len % last_block.block_size
            self._allocate_on_block(last_block, tail_len if tail_len != 0 else last_block.block_size)

    def allocate_token_from_block_table(self, block_table: torch.Tensor, context_len: int) -> int:
        """Allocate the logical cache block for a single sequence during decoding stage,
        and updates the provided block table if a new cache block is needed.

        Args:
            block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
            context_len: The length of the processing sequence (already-allocated length).

        Returns:
            The result of `allocate_single_block` for the target block.
        """
        assert block_table.dim() == 1
        # The last allocated block may be either partially or fully occupied.
        # `alloc_local_block_idx` is the index of block to be allocated on provided block table.
        alloc_local_block_idx = context_len // self.block_size
        return self.allocate_single_block(block_table, alloc_local_block_idx)

    def allocate_tokens_from_block_tables(
        self, block_tables: torch.Tensor, context_lens: torch.Tensor, bsz: int = None
    ) -> List[int]:
        """Allocate logical cache blocks for a batch of sequences during decoding stage.

        Usage:
            allocate_context_from_block_tables
            model forward (block tables & context lengths passed)
            update context lengths
            allocate_tokens_from_block_tables
            model forward
            update context lengths
            allocate_tokens_from_block_tables
            model forward
            update context lengths
            ...

        Args:
            block_tables (torch.Tensor): [bsz, max_blocks_per_sequence]
            context_lens (torch.Tensor): [bsz]
            bsz (int): Number of leading sequences to process; defaults to all rows of `block_tables`.

        Returns:
            List[int]: Row indexes (within `block_tables`) of the sequences whose block tables
                were freed because no block was available for them.
        """
        assert block_tables.dim() == 2
        assert context_lens.dim() == 1

        bsz = block_tables.size(0) if bsz is None else bsz

        alloc_local_block_indexes = (context_lens[:bsz]) // self.block_size
        block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]
        seqs_to_recycle = []
        # `view(-1)` keeps this 1D even when exactly one sequence needs a new block.
        seqs_req_new_blocks = torch.nonzero(block_global_ids < 0).view(-1)
        new_blocks_required = seqs_req_new_blocks.numel()
        if new_blocks_required > 0:
            if new_blocks_required > self._available_blocks:
                # TODO might want to revise the logic here
                # Process the first (_available_blocks) sequences that require new blocks
                # Put the rest of the sequences back to recycled
                num_servable = self._available_blocks
                seqs_to_recycle = seqs_req_new_blocks[num_servable:].tolist()
                seqs_req_new_blocks = seqs_req_new_blocks[:num_servable]
                for seq_id in seqs_to_recycle:
                    self.free_block_table(block_tables[seq_id])
                # Fixed before freeing so the number of allocated ids matches the
                # number of sequences still being served.
                new_blocks_required = num_servable

            # NOTE might want to alloc contiguous logic
            free_block_ids = torch.nonzero(self._block_states > 0).view(-1)
            alloc_block_ids = free_block_ids[:new_blocks_required].to(
                dtype=block_tables.dtype, device=block_tables.device
            )

            for block_id in alloc_block_ids.tolist():
                block: CacheBlock = self._cache_blocks[block_id]
                block.add_ref()
                self._block_states[block_id] = 0
                self._available_blocks -= 1
            block_tables[seqs_req_new_blocks, alloc_local_block_indexes[seqs_req_new_blocks]] = alloc_block_ids
            block_global_ids = block_tables[torch.arange(0, bsz), alloc_local_block_indexes]

        for block_id in block_global_ids.tolist():
            if block_id < 0:
                # Recycled sequence: its block table has been freed, nothing to allocate on.
                continue
            self._allocate_on_block(self._cache_blocks[block_id], 1)

        return seqs_to_recycle

    def allocate_n_tokens_from_block_tables(
        self,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        bsz: int,
        n: int,
    ) -> List[int]:
        """Allocate logical cache blocks for `n` new tokens for a batch of sequences during decoding stage.

        `context_lens` is treated as the lengths after the `n` tokens are appended; the tokens
        are allocated one at a time through `allocate_tokens_from_block_tables`.

        Returns:
            List[int]: Accumulated recycle list from every single-token allocation.
        """
        assert block_tables.dim() == 2
        assert context_lens.dim() == 1

        bsz = block_tables.size(0) if bsz is None else bsz
        assert bsz == 1, "Support bsz 1 for now"  # TODO support bsz > 1

        seqs_to_recycle = []
        for i in range(n):
            seqs_to_recycle += self.allocate_tokens_from_block_tables(block_tables, context_lens - n + i + 1, bsz)

        return seqs_to_recycle

    def allocate_single_block(self, block_table: torch.Tensor, block_local_idx: int) -> int:
        """Allocate space for one token on a single block in the block table, specified by the provided position id,
        and updates the provided block table with the allocated block.

        Args:
            block_table: A 1D tensor of shape [max_blocks_per_sequence], mapping of token_position_id -> block_id.
            block_local_idx: The index of the block in the block table.

        Returns:
            The remaining space required to be allocated (in other blocks); zero only when the
            request is fully satisfied. When no free block exists, the whole block table is
            freed and `True` is returned.
        """
        space_asked = 1
        block_global_id = block_table[block_local_idx].item()
        if block_global_id < 0:
            # Allocate a new block if the current position is not assigned a block yet
            if self._available_blocks <= 0:
                # No available blocks to allocate, we free current sequence and report it to the caller
                self.free_block_table(block_table)
                return True
            free_block_id = torch.nonzero(self._block_states == 1).view(-1)[0].item()
            block: CacheBlock = self._cache_blocks[free_block_id]
            block.add_ref()
            block_global_id = block.block_id
            self._available_blocks -= 1
            self._block_states[block_global_id] = 0
            block_table[block_local_idx] = block_global_id
        block: CacheBlock = self._cache_blocks[block_global_id]
        return self._allocate_on_block(block, space_asked)

    def free_block_table(self, block_table: torch.Tensor) -> None:
        """Free the logical cache blocks for **a single sequence**.

        Processing stops at the first negative (unassigned) entry of the block table.
        """
        assert block_table.dim() == 1
        for i, global_block_id in enumerate(block_table.tolist()):
            if global_block_id < 0:
                return
            self._release_block(global_block_id)
            # reset the block id in the block table (if we maintain a 2D tensors as block tables in Engine)
            block_table[i] = -1

    def free_block_tables(self, block_tables: torch.Tensor, first_n: int = None) -> None:
        """Release the logical cache blocks for a batch of sequences.
        If `first_n` is provided, only the blocks for the first several sequences will be released.
        """
        assert block_tables.dim() == 2
        first_n = block_tables.size(0) if first_n is None else first_n
        for block_table in block_tables[:first_n]:
            self.free_block_table(block_table)

    def clear_all(self) -> None:
        """Clear all the references and allocations on all the cache blocks."""
        for block in self._cache_blocks:
            block.clear()
        self._available_blocks = self.num_blocks
        self._block_states[:] = 1

    def streamingllm_free_block_tables(self, updated_block_ids: List[int]):
        """
        Free the block that needs to be swapped out.

        Processing stops at the first negative id in `updated_block_ids`.
        """
        for global_block_id in updated_block_ids:
            if global_block_id < 0:
                return
            self._release_block(global_block_id)

    def get_physical_cache(self, layer_id: int, block_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get the tensor corresponding to the cache block with the prompted id for a specific layer."""
        return self._kv_caches[0][layer_id][block_idx], self._kv_caches[1][layer_id][block_idx]

    def _select_free_block_ids(self, num_required: int) -> torch.Tensor:
        """Pick `num_required` (> 0) free block ids, preferring one contiguous run.

        The caller must have checked that at least `num_required` blocks are free.
        Block states are not modified here.

        Returns:
            A 1D int64 tensor of block ids, on the same device as `self._block_states`.
        """
        # Prefix sums of the free flags: a window of `num_required` blocks is entirely
        # free iff the difference of the prefix sums at its two ends equals `num_required`.
        torch.cumsum(self._block_states, dim=-1, out=self._block_states_cum[1:])
        torch.subtract(
            self._block_states_cum[num_required:],
            self._block_states_cum[:-num_required],
            out=self._block_finder[num_required - 1 :],
        )
        end_indexes = torch.nonzero(self._block_finder == num_required, as_tuple=False).view(-1)
        if end_indexes.numel() > 0:
            # contiguous cache exists
            end_idx = end_indexes[0].item() + 1  # open interval
            start_idx = end_idx - num_required  # closed interval
            return torch.arange(start_idx, end_idx)
        # non-contiguous cache: take the first free blocks (state 1 -> free)
        return torch.nonzero(self._block_states > 0).view(-1)[:num_required]

    def _release_block(self, global_block_id: int) -> None:
        """Drop one reference on a block and mark it free once it has no references left."""
        block: CacheBlock = self._cache_blocks[global_block_id]
        block.remove_ref()
        if not block.has_ref():
            block.allocated_size = 0
            self._available_blocks += 1
            self._block_states[global_block_id] = 1

    def _allocate_on_block(self, block: CacheBlock, space_asked: int) -> int:
        """Allocate a specific size of space on a provided cache block.

        Returns:
            The remaining space required to be allocated (in other blocks).
        """
        assert block.available_space > 0, f"Found no available space left in the chosen block {block}."
        space_to_allocate = min(block.available_space, space_asked)
        block.allocate(space_to_allocate)
        return space_asked - space_to_allocate

    def _init_logical_caches(self):
        """Initialize the logical cache blocks.

        NOTE This function should be called only after the physical caches have been allocated.
        The data pointers of physical caches will be binded to each logical cache block.
        """
        assert self._kv_caches is not None and len(self._kv_caches[0]) > 0
        physical_block_size = self.elem_size_in_bytes * self.block_size * self.kv_head_num * self.head_size
        k_base_ptrs = [self._kv_caches[0][layer_idx].data_ptr() for layer_idx in range(self.num_layers)]
        v_base_ptrs = [self._kv_caches[1][layer_idx].data_ptr() for layer_idx in range(self.num_layers)]
        blocks = []
        for i in range(self.num_blocks):
            # Block `i` starts `i` physical blocks after the start of each layer's cache tensor.
            byte_offset = i * physical_block_size
            k_ptrs = [base_ptr + byte_offset for base_ptr in k_base_ptrs]
            v_ptrs = [base_ptr + byte_offset for base_ptr in v_base_ptrs]
            blocks.append(CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs, v_ptrs))
        return blocks

    def _init_device_caches(
        self, kalloc_shape: Tuple[int, ...], valloc_shape: Tuple[int, ...]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Initialize the physical cache on the device.

        For each layer of the model, we allocate two zero-filled tensors for key and value
        respectively, with the provided shapes and dtype `self.kv_cache_dtype`.

        Returns:
            A pair (k_cache, v_cache), each a list with one tensor per layer.
        """
        k_cache: List[torch.Tensor] = []
        v_cache: List[torch.Tensor] = []
        for _ in range(self.num_layers):
            k_cache.append(torch.zeros(kalloc_shape, dtype=self.kv_cache_dtype, device=self.device))
            v_cache.append(torch.zeros(valloc_shape, dtype=self.kv_cache_dtype, device=self.device))
        return k_cache, v_cache
class RPCKVCacheManager(KVCacheManager):
    """KVCacheManager variant that only tracks logical cache blocks.

    No physical cache tensors are allocated by this class; `get_physical_cache_shape`
    reports the shapes the K and V caches should have.
    """

    def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verbose: bool = False) -> None:
        """Set up block bookkeeping without allocating device caches.

        NOTE: The parent initializer is intentionally not called, since it allocates
        the physical caches on the current device.

        Args:
            config: Inference settings (parallelism, dtypes, batch/length limits, block size).
            model_config: Model settings (layer count, head counts, hidden size).
            verbose: Currently unused.
        """
        self.logger = get_dist_logger(__name__)
        self.device = get_current_device()
        self.config = config

        # Parallel settings
        self.tp_size = config.tp_size

        # Model settings
        self.dtype = config.dtype
        if config.kv_cache_dtype is None:
            self.kv_cache_dtype = config.dtype
        else:
            self.kv_cache_dtype = config.kv_cache_dtype
        # Element size follows the dtype the KV cache is stored in.
        self.elem_size_in_bytes = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
        self.num_layers = model_config.num_hidden_layers
        self.head_num = model_config.num_attention_heads
        self.head_size = model_config.hidden_size // self.head_num
        if hasattr(model_config, "num_key_value_heads"):
            self.kv_head_num = model_config.num_key_value_heads
        else:
            self.kv_head_num = self.head_num

        assert (
            self.kv_head_num % self.tp_size == 0
        ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
        self.kv_head_num //= self.tp_size
        self.beam_width = config.beam_width
        self.max_batch_size = config.max_batch_size
        self.max_input_length = config.max_input_len
        self.max_output_length = config.max_output_len

        # Cache block settings
        self.block_size = config.block_size
        # NOTE: `num_blocks` is not prompted, but evaluated from the maximum input/output length, and the maximum batch size
        if config.enable_streamingllm:
            self.max_blocks_per_sequence = (
                config.start_token_size + config.generated_token_size + self.block_size - 1
            ) // self.block_size + 1
        else:
            self.max_blocks_per_sequence = (
                self.max_input_length + self.max_output_length + self.block_size - 1
            ) // self.block_size
        self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

        # Logical cache blocks allocation
        self._available_blocks = self.num_blocks
        self._cache_blocks = tuple(self._init_logical_caches())
        # block availability state 0->allocated, 1->free
        self._block_states = torch.ones((self.num_blocks,), dtype=torch.bool)
        self._block_states_cum = torch.zeros(size=(self.num_blocks + 1,), dtype=torch.int64)
        self._block_finder = torch.zeros((self.num_blocks,), dtype=torch.int64)

    def get_physical_cache_shape(self) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
        """Compute the shapes of the K and V physical caches (per layer).

        Returns:
            A pair (kalloc_shape, valloc_shape).
        """
        # Physical cache allocation
        if self.config.use_cuda_kernel:
            x = 16 // torch.tensor([], dtype=self.config.dtype).element_size()
            kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
            valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
            self.logger.info(
                f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
            )
        else:
            alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
            kalloc_shape = alloc_shape
            valloc_shape = alloc_shape
            self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
        return kalloc_shape, valloc_shape

    def get_kv_cache(self):
        """Get k_cache and v_cache

        Raises:
            NotImplementedError: Always; this manager does not hold physical caches.
        """
        raise NotImplementedError

    def _init_logical_caches(self):
        """Initialize the logical cache blocks (without physical cache pointers)."""
        return [
            CacheBlock(i, self.block_size, self.elem_size_in_bytes, k_ptrs=None, v_ptrs=None)
            for i in range(self.num_blocks)
        ]