mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
179 lines
6.6 KiB
179 lines
6.6 KiB
import random |
|
|
|
import pytest |
|
import torch |
|
from transformers.models.llama import LlamaConfig |
|
|
|
import colossalai |
|
from colossalai.inference.config import InferenceConfig |
|
from colossalai.inference.kv_cache import CacheBlock, KVCacheManager |
|
from colossalai.logging import disable_existing_loggers |
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn |
|
|
|
|
|
@parameterize(
    "test_config",
    [{"elem_size": 2, "block_size": 4}],
)
def test_logical_blocks(test_config):
    """Exercise a single ``CacheBlock``: emptiness, reference counting and allocation.

    Args:
        test_config: Dict supplying ``block_size`` and ``elem_size`` for the block.
    """
    capacity = test_config["block_size"]
    cache_block = CacheBlock(block_id=0, block_size=capacity, elem_size=test_config["elem_size"])

    # A freshly built block is empty, exposes its full capacity and has no references.
    assert cache_block.is_empty()
    assert cache_block.available_space == capacity
    assert not cache_block.has_ref()

    # One add_ref/remove_ref round trip brings the counter back to zero.
    cache_block.add_ref()
    assert cache_block.ref_count == 1
    assert cache_block.has_ref()
    cache_block.remove_ref()
    assert cache_block.ref_count == 0

    # Allocate one slot, then the remaining ones; no space may be left afterwards.
    cache_block.allocate(1)
    assert cache_block.allocated_size == 1
    remaining_slots = capacity - 1
    cache_block.allocate(remaining_slots)
    assert cache_block.available_space < 1
|
|
|
|
|
@parameterize(
    "test_config",
    [
        {
            "hidden_size": 512,
            "num_attention_heads": 16,
            "num_layers": 2,
            "block_size": 8,
            "max_batch_size": 10,
            "max_input_len": 32,
            "max_output_len": 32,
            "dtype": torch.float32,
            "beam_width": 1,
            "tp_size": 1,
        },
        {
            "hidden_size": 128,
            "num_attention_heads": 4,
            "num_layers": 3,
            "block_size": 4,
            "max_batch_size": 4,
            "max_input_len": 64,
            "max_output_len": 32,
            "dtype": torch.float16,
            "beam_width": 3,
            "tp_size": 1,
        },
    ],
)
def check_cache_manager(test_config):
    """Check ``KVCacheManager`` allocation, freeing and layout for one configuration.

    The check builds a manager from ``test_config``, verifies the cache shapes,
    mocks a prefill and a decoding phase on per-request block tables, frees one
    request, verifies the pointer stride between two blocks, and finally runs the
    batched allocate/free entry points.

    Args:
        test_config: Dict holding the model dimensions (``hidden_size``,
            ``num_attention_heads``, ``num_layers``) plus the keyword arguments
            forwarded to ``InferenceConfig``. The dict itself is left untouched.
    """
    disable_existing_loggers()

    # Work on a shallow copy: the pops below must not strip keys from the dict owned
    # by the @parameterize list, otherwise calling this check a second time with the
    # same dict would fail with KeyError.
    test_config = dict(test_config)

    assert test_config["max_batch_size"] > 1

    # Model-only settings are removed so that the remainder can be splatted into InferenceConfig.
    hidden_size = test_config.pop("hidden_size")
    num_layers = test_config.pop("num_layers")
    num_attention_heads = test_config.pop("num_attention_heads")
    head_size = hidden_size // num_attention_heads
    block_size = test_config["block_size"]
    max_batch_size = test_config["max_batch_size"]
    max_input_length = test_config["max_input_len"]
    max_output_length = test_config["max_output_len"]

    inference_config = InferenceConfig(**test_config)
    model_config = LlamaConfig(
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        num_attention_heads=num_attention_heads,
    )
    cache_manager = KVCacheManager(inference_config, model_config)

    num_blocks = cache_manager.total_num_blocks
    assert num_blocks > 0
    assert len(cache_manager._cache_blocks) == num_blocks
    key_caches = cache_manager._kv_caches[0]  # key caches for all the blocks in all the layers
    assert len(key_caches) == num_layers
    expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size)
    assert key_caches[0].shape == expected_kv_shape
    k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
    # A single block drops the leading num_blocks dimension.
    expected_kv_block_shape = expected_kv_shape[1:]
    assert k_cache_block0.shape == expected_kv_block_shape
    assert v_cache_block0.shape == expected_kv_block_shape

    max_blocks_per_seq = cache_manager.get_max_blocks_per_sequence()
    # -1 marks a block-table slot that has not been assigned a block yet.
    block_tables = torch.full((max_batch_size, max_blocks_per_seq), -1, dtype=torch.int32)
    context_lengths = [random.randint(1, max_input_length) for _ in range(max_batch_size)]
    cnt_blocks_used = 0
    # Mock Prefill
    for req_i in range(max_batch_size):
        cur_seq_len = context_lengths[req_i]
        cur_block_table = block_tables[req_i]
        cache_manager.allocate_context_from_block_table(cur_block_table, cur_seq_len)
        last_allocated_idx = (cur_seq_len - 1) // block_size
        assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0)
        cnt_blocks_used += torch.sum(cur_block_table >= 0).item()
    assert cache_manager.num_available_blocks == num_blocks - cnt_blocks_used

    # Mock Decoding
    for req_i in range(max_batch_size):
        context_length = context_lengths[req_i]
        cur_output_length = random.randint(1, max_output_length)
        cur_block_table = block_tables[req_i]
        for _ in range(cur_output_length):
            cache_manager.allocate_token_from_block_table(cur_block_table, context_length)
            context_length += 1
        # Step back to the last value that was actually passed to the allocator.
        context_length -= 1
        last_allocated_idx = context_length // block_size
        space_allocated_on_last_block = context_length % block_size + 1
        assert space_allocated_on_last_block > 0
        # Convert the tensor element to a plain int before indexing the Python list.
        block_id = cur_block_table[last_allocated_idx].item()
        block: CacheBlock = cache_manager._cache_blocks[block_id]
        assert block.allocated_size == space_allocated_on_last_block

    # Randomly select a request and clear its cache
    req_i = random.randint(0, max_batch_size - 1)
    blocks_used_by_req = torch.sum(block_tables[req_i] >= 0).item()
    prev_available_blocks = cache_manager.num_available_blocks
    cache_manager.free_block_table(block_tables[req_i])
    assert cache_manager.num_available_blocks == blocks_used_by_req + prev_available_blocks

    # The key pointers of blocks 0 and 1 in layer 0 must be exactly one block's worth of bytes apart.
    k_ptr_block0_layer0, _ = cache_manager.get_block_kv_ptrs(0, 0)
    k_ptr_block1_layer0, _ = cache_manager.get_block_kv_ptrs(1, 0)
    elem_size = torch.tensor([], dtype=test_config["dtype"]).element_size()
    expected_stride = block_size * num_attention_heads * head_size * elem_size
    assert k_ptr_block1_layer0 - k_ptr_block0_layer0 == expected_stride
    cache_manager.clear_all()
    assert cache_manager.num_available_blocks == num_blocks

    for cache_block in cache_manager._cache_blocks:
        assert cache_block.available_space == block_size

    # Mock batch operations (Prefill/Decoding updates)
    context_lengths = torch.tensor([max_input_length, max_input_length - 1])
    block_tables = torch.full((2, cache_manager.max_blocks_per_sequence), -1, dtype=torch.int32)
    cache_manager.allocate_context_from_block_tables(block_tables, context_lengths)
    cache_manager.allocate_tokens_from_block_tables(block_tables, context_lengths)
    cache_manager.free_block_tables(block_tables)
    # After freeing both sequences every block must be back to full capacity.
    for cache_block in cache_manager._cache_blocks:
        assert cache_block.available_space == block_size
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Launch colossalai on localhost for this rank, then run the cache-manager check.

    Args:
        rank: Forwarded to ``colossalai.launch`` as ``rank``.
        world_size: Forwarded to ``colossalai.launch`` as ``world_size``.
        port: Forwarded to ``colossalai.launch`` as ``port``.
    """
    launch_kwargs = dict(rank=rank, world_size=world_size, port=port, host="localhost")
    colossalai.launch(**launch_kwargs)
    check_cache_manager()
|
|
|
|
|
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager():
    """Run ``run_dist`` through ``spawn``."""
    # Second positional argument of spawn; presumably the number of processes - confirm against colossalai.testing.
    nprocs = 1
    spawn(run_dist, nprocs)
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution: run the block-level test first, then the spawned manager test.
    for run_test in (test_logical_blocks, test_cache_manager):
        run_test()
|
|
|