Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

179 lines
6.6 KiB

import random
import pytest
import torch
from transformers.models.llama import LlamaConfig
import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize(
    "test_config",
    [
        {
            "elem_size": 2,
            "block_size": 4,
        }
    ],
)
def test_logical_blocks(test_config):
    """Exercise reference counting and space allocation on a single ``CacheBlock``.

    Args:
        test_config: Dict providing ``block_size`` and ``elem_size`` for the block.
    """
    capacity = test_config["block_size"]
    cache_block = CacheBlock(block_id=0, block_size=capacity, elem_size=test_config["elem_size"])

    # A freshly built block: empty, fully available, unreferenced.
    assert cache_block.is_empty()
    assert cache_block.available_space == capacity
    assert not cache_block.has_ref()

    # Reference-count round trip: 0 -> 1 -> 0.
    cache_block.add_ref()
    assert cache_block.ref_count == 1
    assert cache_block.has_ref()
    cache_block.remove_ref()
    assert cache_block.ref_count == 0

    # Fill the block in two steps: one slot, then whatever capacity remains.
    cache_block.allocate(1)
    assert cache_block.allocated_size == 1
    cache_block.allocate(capacity - 1)
    assert cache_block.available_space < 1
@parameterize(
    "test_config",
    [
        {
            "hidden_size": 512,
            "num_attention_heads": 16,
            "num_layers": 2,
            "block_size": 8,
            "max_batch_size": 10,
            "max_input_len": 32,
            "max_output_len": 32,
            "dtype": torch.float32,
            "beam_width": 1,
            "tp_size": 1,
        },
        {
            "hidden_size": 128,
            "num_attention_heads": 4,
            "num_layers": 3,
            "block_size": 4,
            "max_batch_size": 4,
            "max_input_len": 64,
            "max_output_len": 32,
            "dtype": torch.float16,
            "beam_width": 3,
            "tp_size": 1,
        },
    ],
)
def check_cache_manager(test_config):
    """Check KVCacheManager layout, allocation bookkeeping and freeing for one configuration.

    The model-shape keys (``hidden_size``, ``num_layers``, ``num_attention_heads``) are
    split off to build a ``LlamaConfig``. Every remaining key is forwarded to
    ``InferenceConfig``.

    Args:
        test_config: One parameterized configuration dict. It is not modified.
    """
    disable_existing_loggers()
    assert test_config["max_batch_size"] > 1

    # Work on a shallow copy so the parameterized dict itself is never mutated. Popping
    # from the shared dict in place would make any later call that receives the same
    # object (e.g. a second run in this process) fail with KeyError.
    test_config = dict(test_config)
    hidden_size = test_config.pop("hidden_size")
    num_layers = test_config.pop("num_layers")
    num_attention_heads = test_config.pop("num_attention_heads")
    head_size = hidden_size // num_attention_heads
    block_size = test_config["block_size"]
    max_batch_size = test_config["max_batch_size"]
    max_input_length = test_config["max_input_len"]
    max_output_length = test_config["max_output_len"]

    inference_config = InferenceConfig(**test_config)
    model_config = LlamaConfig(
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        num_attention_heads=num_attention_heads,
    )
    cache_manager = KVCacheManager(inference_config, model_config)

    # Physical cache layout.
    num_blocks = cache_manager.total_num_blocks
    assert num_blocks > 0
    assert len(cache_manager._cache_blocks) == num_blocks
    key_caches = cache_manager._kv_caches[0]  # key caches for all the blocks in all the layers
    assert len(key_caches) == num_layers
    expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size)
    assert key_caches[0].shape == expected_kv_shape
    k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
    expected_kv_block_shape = expected_kv_shape[1:]
    assert k_cache_block0.shape == expected_kv_block_shape
    assert v_cache_block0.shape == expected_kv_block_shape

    # One row per request. -1 marks a slot with no block assigned yet.
    max_blocks_per_seq = cache_manager.get_max_blocks_per_sequence()
    block_tables = torch.full((max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
    context_lengths = [random.randint(1, max_input_length) for _ in range(max_batch_size)]

    # Mock Prefill. Iterating the 2-D tensor yields row views, so in-place block-table
    # updates made by the manager land in ``block_tables``.
    cnt_blocks_used = 0
    for cur_seq_len, cur_block_table in zip(context_lengths, block_tables):
        cache_manager.allocate_context_from_block_table(cur_block_table, cur_seq_len)
        last_allocated_idx = (cur_seq_len - 1) // block_size
        assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0)
        cnt_blocks_used += torch.sum(cur_block_table >= 0).item()
    assert cache_manager.num_available_blocks == num_blocks - cnt_blocks_used

    # Mock Decoding: allocate one token at a time for a random number of steps.
    for context_length, cur_block_table in zip(context_lengths, block_tables):
        cur_output_length = random.randint(1, max_output_length)
        for step in range(cur_output_length):
            cache_manager.allocate_token_from_block_table(cur_block_table, context_length + step)
        # Position of the last token allocated above.
        last_token_pos = context_length + cur_output_length - 1
        last_allocated_idx = last_token_pos // block_size
        space_allocated_on_last_block = last_token_pos % block_size + 1
        block_id = cur_block_table[last_allocated_idx].item()
        block: CacheBlock = cache_manager._cache_blocks[block_id]
        assert block.allocated_size == space_allocated_on_last_block

    # Randomly select a request and clear its cache.
    req_i = random.randint(0, max_batch_size - 1)
    blocks_used_by_req = torch.sum(block_tables[req_i] >= 0).item()
    prev_available_blocks = cache_manager.num_available_blocks
    cache_manager.free_block_table(block_tables[req_i])
    assert cache_manager.num_available_blocks == blocks_used_by_req + prev_available_blocks

    # Consecutive blocks of one layer's key cache are expected to be one block-stride apart in memory.
    k_ptr_block0_layer0, _ = cache_manager.get_block_kv_ptrs(0, 0)
    k_ptr_block1_layer0, _ = cache_manager.get_block_kv_ptrs(1, 0)
    elem_size = torch.tensor([], dtype=test_config["dtype"]).element_size()
    expected_stride = block_size * num_attention_heads * head_size * elem_size
    assert k_ptr_block1_layer0 - k_ptr_block0_layer0 == expected_stride

    cache_manager.clear_all()
    assert cache_manager.num_available_blocks == num_blocks
    for cache_block in cache_manager._cache_blocks:
        assert cache_block.available_space == block_size

    # Mock batch operations (Prefill/Decoding updates)
    context_lengths = torch.tensor([max_input_length, max_input_length - 1])
    block_tables = torch.full((2, cache_manager.max_blocks_per_sequence), -1, dtype=torch.int32)
    cache_manager.allocate_context_from_block_tables(block_tables, context_lengths)
    cache_manager.allocate_tokens_from_block_tables(block_tables, context_lengths)
    cache_manager.free_block_tables(block_tables)
    for cache_block in cache_manager._cache_blocks:
        assert cache_block.available_space == block_size
def run_dist(rank, world_size, port):
    """Per-process worker: bring up the ColossalAI context, then run the cache-manager checks.

    Args:
        rank: Rank of this process.
        world_size: Total number of processes.
        port: Port passed to ``colossalai.launch`` (host is ``localhost``).
    """
    launch_kwargs = {"rank": rank, "world_size": world_size, "port": port, "host": "localhost"}
    colossalai.launch(**launch_kwargs)
    check_cache_manager()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager():
    """Run ``run_dist`` in one spawned process.

    NOTE(review): the decorator presumably retries the test when the rendezvous
    address is already in use — confirm against ``colossalai.testing``.
    """
    num_procs = 1
    spawn(run_dist, num_procs)
if __name__ == "__main__":
    # Allow running this file directly, outside pytest: logical-block test first,
    # then the spawned cache-manager test.
    for _entry in (test_logical_blocks, test_cache_manager):
        _entry()