mirror of https://github.com/hpcaitech/ColossalAI
[feature] add KV cache manager for llama & bloom inference (#4495)
* add kv cache memory manager
* add stateinfo during inference
* format
* format
* rename file
* add kv cache test
* revise on BatchInferState
* file dir change

branch: pull/4523/head^2
parent 64110b12c0
commit 2226c6836c
@@ -0,0 +1,4 @@ colossalai/shardformer/inference/__init__.py

from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager

__all__ = ['BatchInferState', 'MemoryManager']
@@ -0,0 +1,52 @@ colossalai/shardformer/inference/batch_infer_state.py

# might want to consider combining with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass
from typing import Any

import torch

from .kvcache_manager import MemoryManager


@dataclass
class BatchInferState:
    r"""
    Information to be passed and used for a batch of inputs during
    a single model forward
    """
    batch_size: int
    max_len_in_batch: int

    cache_manager: MemoryManager = None

    block_loc: torch.Tensor = None
    start_loc: torch.Tensor = None
    seq_len: torch.Tensor = None

    is_context_stage: bool = False
    context_mem_index: torch.Tensor = None
    decode_is_contiguous: bool = None
    decode_mem_start: int = None
    decode_mem_end: int = None
    decode_mem_index: torch.Tensor = None
    decode_layer_id: int = None

    device: torch.device = torch.device('cuda')

    @property
    def total_token_num(self):
        return self.batch_size * self.max_len_in_batch

    def set_cache_manager(self, manager: MemoryManager):
        self.cache_manager = manager

    @staticmethod
    def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int,
                       alloc_mem_index: torch.Tensor):
        """ In-place update of the block-loc mapping based on the sequence lengths of the inputs in the current batch """
        start_index = 0
        seq_len_numpy = seq_len.cpu().numpy()
        for i, cur_seq_len in enumerate(seq_len_numpy):
            # right-align each sequence's allocated cache indexes within its row of b_loc
            b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = \
                alloc_mem_index[start_index:start_index + cur_seq_len]
            start_index += cur_seq_len
        return
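For illustration only (not part of this commit): a minimal sketch of what `BatchInferState.init_block_loc` produces, assuming `colossalai` with this module is importable. The toy sequence lengths and slot indexes below are made up; the point is that each sequence's allocated cache slots are written right-aligned into its row of `block_loc`.

import torch

from colossalai.shardformer.inference import BatchInferState

# two sequences of length 3 and 5, padded to max_len_in_batch = 6
seq_len = torch.tensor([3, 5], dtype=torch.long)
max_len_in_batch = 6
batch_size, total_tokens = len(seq_len), int(seq_len.sum())

# pretend the cache manager handed back 8 slot indexes (3 + 5 tokens)
alloc_mem_index = torch.arange(100, 100 + total_tokens, dtype=torch.long)

b_loc = torch.zeros((batch_size, max_len_in_batch), dtype=torch.long)
BatchInferState.init_block_loc(b_loc, seq_len, max_len_in_batch, alloc_mem_index)

# row 0 gets its 3 slots right-aligned: [0, 0, 0, 100, 101, 102]
# row 1 gets its 5 slots right-aligned: [0, 103, 104, 105, 106, 107]
print(b_loc)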
@@ -0,0 +1,116 @@ colossalai/shardformer/inference/kvcache_manager.py

# Adapted from lightllm/common/mem_manager.py
# of the ModelTC/lightllm GitHub repository
# https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
#
# Copyright 2023 ModelTC Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from colossalai.logging import get_dist_logger


class MemoryManager:
    r"""
    Manage token block indexes and allocate physical memory for the key and value cache.

    Args:
        size: maximum token number used as the size of the key and value buffers
        dtype: data type of the cached keys and values
        head_num: number of heads the memory manager is responsible for
        head_dim: embedding size per head
        layer_num: the number of layers in the model
        device: device used to store the key and value cache
    """

    def __init__(self,
                 size: int,
                 dtype: torch.dtype,
                 head_num: int,
                 head_dim: int,
                 layer_num: int,
                 device: torch.device = torch.device('cuda')):
        self.logger = get_dist_logger(__name__)
        self.available_size = size
        self.past_key_values_length = 0
        self._init_mem_states(size, device)
        self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)

    def _init_mem_states(self, size, device):
        """ Initialize tensors used to manage memory states """
        self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
        self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
        self.indexes = torch.arange(0, size, dtype=torch.long, device=device)

    def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
        """ Initialize key buffer and value buffer on the specified device """
        self.key_buffer = [
            torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
        ]
        self.value_buffer = [
            torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
        ]

    @torch.no_grad()
    def alloc(self, required_size):
        """ Allocate `required_size` slots by returning indexes of available physical spaces """
        if required_size > self.available_size:
            self.logger.warning(f"Not enough cache: required_size {required_size} "
                                f"left_size {self.available_size}")
            return None
        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
        # pick the first `required_size` free slots
        select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
        select_index = self.indexes[select_index]
        self.mem_state[select_index] = 0
        self.available_size -= len(select_index)
        return select_index

    @torch.no_grad()
    def alloc_contiguous(self, required_size):
        """ Allocate a contiguous block of `required_size` slots """
        if required_size > self.available_size:
            self.logger.warning(f"Not enough cache: required_size {required_size} "
                                f"left_size {self.available_size}")
            return None
        torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
        sum_size = len(self.mem_cum_sum)
        # number of free slots inside each window of length `required_size`
        loc_sums = (self.mem_cum_sum[required_size - 1:] - self.mem_cum_sum[0:sum_size - required_size + 1]
                    + self.mem_state[0:sum_size - required_size + 1])
        can_used_loc = self.indexes[0:sum_size - required_size + 1][loc_sums == required_size]
        if can_used_loc.shape[0] == 0:
            self.logger.info(f"Not enough contiguous cache: required_size {required_size} "
                             f"left_size {self.available_size}")
            return None
        start_loc = can_used_loc[0]
        select_index = self.indexes[start_loc:start_loc + required_size]
        self.mem_state[select_index] = 0
        self.available_size -= len(select_index)
        start = start_loc.item()
        end = start + required_size
        return select_index, start, end

    @torch.no_grad()
    def free(self, free_index):
        """ Free memory by updating memory states based on the given indexes """
        self.available_size += free_index.shape[0]
        self.mem_state[free_index] = 1

    @torch.no_grad()
    def free_all(self):
        """ Free all memory by resetting memory states """
        self.available_size = len(self.mem_state)
        self.mem_state[:] = 1
        self.past_key_values_length = 0
        self.logger.info("freed all space of memory manager")
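As a side note on how `alloc_contiguous` finds a free window: it builds a sliding-window sum over the free/used mask via the cumulative sum and keeps the window starts whose free-slot count equals the requested size. Below is a self-contained sketch of that same trick on a made-up toy mask, using plain PyTorch only (no `MemoryManager` instance needed).

import torch

# toy occupancy mask: 1 = free slot, 0 = used slot
mem_state = torch.tensor([0, 1, 1, 0, 1, 1, 1, 0, 1, 1], dtype=torch.int32)
required_size = 3

mem_cum_sum = torch.cumsum(mem_state, dim=0, dtype=torch.int32)
indexes = torch.arange(len(mem_state), dtype=torch.long)
sum_size = len(mem_cum_sum)

# free-slot count of every window of length `required_size`
loc_sums = (mem_cum_sum[required_size - 1:] - mem_cum_sum[0:sum_size - required_size + 1]
            + mem_state[0:sum_size - required_size + 1])

# window start positions where all `required_size` slots are free
can_used_loc = indexes[0:sum_size - required_size + 1][loc_sums == required_size]
print(can_used_loc)    # tensor([4]) -> slots 4, 5, 6 form the only contiguous free block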
@@ -0,0 +1,60 @@ (kv cache manager test)

import os

import pytest
import torch

from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.inference import MemoryManager
from colossalai.testing import rerun_if_address_is_in_use, spawn

BATCH_SIZE = 4
INPUT_LEN = 16
OUTPUT_LEN = 8
LAYER_NUM = 4
HEAD_NUM = 32
HEAD_DIM = 128


def create_cache_manager(rank, world_size, port, batch_size, input_len, output_len, layer_num, head_num, head_dim):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    disable_existing_loggers()

    size = batch_size * (input_len + output_len)
    kvcache_manager = MemoryManager(size, torch.float16, head_num // world_size, head_dim, layer_num, rank)
    key_buffers = kvcache_manager.key_buffer
    value_buffers = kvcache_manager.value_buffer
    assert len(key_buffers) == len(value_buffers) == layer_num
    assert key_buffers[0].shape == value_buffers[0].shape
    # required size exceeds the maximum allocated size
    invalid_locs = kvcache_manager.alloc_contiguous(size + 1)
    assert invalid_locs is None
    # for the prefill stage, allocation via alloc and alloc_contiguous should be the same
    total_token_prefill = batch_size * input_len
    prefill_locs = kvcache_manager.alloc(total_token_prefill)
    kvcache_manager.free_all()
    prefill_locs_contiguous = kvcache_manager.alloc_contiguous(total_token_prefill)[0]
    assert torch.equal(prefill_locs, prefill_locs_contiguous)
    assert torch.sum(kvcache_manager.mem_state).item() == size - total_token_prefill
    kvcache_manager.alloc_contiguous(batch_size)
    assert torch.all(kvcache_manager.mem_state[:total_token_prefill + batch_size] == False)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_cache_manager_dist():
    spawn(create_cache_manager,
          4,
          batch_size=BATCH_SIZE,
          input_len=INPUT_LEN,
          output_len=OUTPUT_LEN,
          layer_num=LAYER_NUM,
          head_num=HEAD_NUM,
          head_dim=HEAD_DIM)


if __name__ == '__main__':
    test_cache_manager_dist()
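To show how the two new classes could fit together, here is an illustrative sketch (not taken from this commit) of wiring a `MemoryManager` into a `BatchInferState` for a prefill step: contiguous-enough slots are reserved for all prompt tokens and recorded in `block_loc`. Variable names such as `input_lengths` are hypothetical, the head/layer sizes simply mirror the test constants, and the snippet assumes a CUDA device and a working ColossalAI install.

import torch

from colossalai.shardformer.inference import BatchInferState, MemoryManager

# hypothetical batch: 2 prompts of 5 and 7 tokens, with room for 8 generated tokens each
input_lengths = torch.tensor([5, 7], dtype=torch.int32, device='cuda')
max_output_len = 8
batch_size = input_lengths.shape[0]
max_len_in_batch = int(input_lengths.max()) + max_output_len

cache_manager = MemoryManager(size=batch_size * max_len_in_batch,
                              dtype=torch.float16,
                              head_num=32,
                              head_dim=128,
                              layer_num=4)

infer_state = BatchInferState(batch_size=batch_size, max_len_in_batch=max_len_in_batch)
infer_state.set_cache_manager(cache_manager)
infer_state.seq_len = input_lengths
infer_state.is_context_stage = True

# reserve one cache slot per prompt token and record the mapping in block_loc
total_prompt_tokens = int(input_lengths.sum())
infer_state.context_mem_index = cache_manager.alloc(total_prompt_tokens)
infer_state.block_loc = torch.zeros((batch_size, max_len_in_batch), dtype=torch.long, device='cuda')
BatchInferState.init_block_loc(infer_state.block_loc, input_lengths, max_len_in_batch,
                               infer_state.context_mem_index)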