You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/inference/flash_decoding_utils.py

64 lines
2.4 KiB

import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils import get_current_device
class FDIntermTensors(metaclass=SingletonMeta):
    """Singleton class to hold tensors used for storing intermediate values in flash-decoding.

    For now, it holds the intermediate output and the logsumexp (which will be
    used in the reduction step along kv). The buffers are allocated once by
    `initialize` and can be released with `_reset`.
    """

    def __init__(self):
        # Buffers are allocated lazily by `initialize`; until then the
        # `mid_output` / `mid_output_lse` properties refuse access.
        self._tensors_initialized = False

    def _reset(self):
        """Drop the intermediate tensors and mark the holder as uninitialized.

        Safe to call before `initialize` or several times in a row: only the
        buffers that actually exist are deleted.
        """
        self._tensors_initialized = False
        for attr_name in ("_mid_output", "_mid_output_lse"):
            # The attributes only exist after a successful `initialize`, so an
            # unconditional `del` would raise AttributeError on a fresh holder.
            if hasattr(self, attr_name):
                delattr(self, attr_name)

    @property
    def is_initialized(self):
        """bool: Whether `initialize` has allocated the intermediate tensors."""
        return self._tensors_initialized

    @property
    def mid_output(self):
        """torch.Tensor: Intermediate output buffer of shape
        (max_batch_size, num_attn_heads, kv_max_split_num, head_dim).

        Raises:
            AssertionError: If the tensors have not been initialized yet.
        """
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output

    @property
    def mid_output_lse(self):
        """torch.Tensor: Intermediate logsumexp buffer of shape
        (max_batch_size, num_attn_heads, kv_max_split_num).

        Raises:
            AssertionError: If the tensors have not been initialized yet.
        """
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output_lse

    def initialize(
        self,
        max_batch_size: int,
        num_attn_heads: int,
        kv_max_split_num: int,
        head_dim: int,
        dtype: torch.dtype = torch.float32,
        device: torch.device = None,
    ) -> None:
        """Initialize tensors.

        Args:
            max_batch_size (int): The maximum batch size over all the model forward.
                This could be greater than the batch size in attention forward func when using dynamic batch size.
            num_attn_heads (int): Number of attention heads.
            kv_max_split_num (int): The maximum number of blocks split on kv in flash-decoding algorithm.
                **The maximum length/size of blocks split on kv should be the kv cache block size.**
            head_dim (int): Head dimension.
            dtype (torch.dtype, optional): Data type to be assigned to intermediate tensors.
                Defaults to torch.float32.
            device (torch.device, optional): Device used to initialize intermediate tensors.
                Defaults to None, in which case `get_current_device()` is queried when
                this method is called.

        Raises:
            AssertionError: If the tensors have already been initialized.
        """
        assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."

        # Resolve the default device at call time. A default of
        # `get_current_device()` in the signature would be evaluated once, when
        # the class body is executed at import, and would ignore any device
        # selected afterwards.
        if device is None:
            device = get_current_device()

        # `torch.empty` leaves the contents uninitialized.
        self._mid_output = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
        )
        self._mid_output_lse = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )

        self._tensors_initialized = True