ColossalAI/colossalai/inference/flash_decoding_utils.py

83 lines
3.0 KiB
Python

import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils import get_current_device
class FDIntermTensors(metaclass=SingletonMeta):
    """Singleton class to hold tensors used for storing intermediate values in flash-decoding.

    For now, it holds the intermediate output and logsumexp (which will be used in the
    reduction step along kv), plus the exp-sum and max-logit buffers. The buffers are
    allocated once by :meth:`initialize` and exposed through read-only properties; every
    property asserts that :meth:`initialize` has been called first.
    """

    # Names of the buffer attributes created by ``initialize`` and dropped by ``_reset``.
    _BUFFER_ATTRS = ("_mid_output", "_mid_output_lse", "_exp_sums", "_max_logits")

    def __init__(self):
        # Buffers are created lazily in ``initialize``; only the flag exists up front.
        self._tensors_initialized = False

    def _reset(self):
        """Drop all intermediate tensors and mark the holder as uninitialized.

        Safe to call when the tensors were never initialized (or were already reset):
        attributes that do not exist are simply skipped instead of raising
        ``AttributeError``.
        """
        self._tensors_initialized = False
        for attr_name in self._BUFFER_ATTRS:
            # ``pop`` with a default releases the reference if present, no-op otherwise.
            self.__dict__.pop(attr_name, None)

    @property
    def is_initialized(self):
        """bool: Whether :meth:`initialize` has allocated the intermediate tensors."""
        return self._tensors_initialized

    @property
    def mid_output(self):
        """torch.Tensor: Buffer of shape (max_batch_size, num_attn_heads, kv_max_split_num, head_dim)."""
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output

    @property
    def mid_output_lse(self):
        """torch.Tensor: Buffer of shape (max_batch_size, num_attn_heads, kv_max_split_num)."""
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output_lse

    @property
    def exp_sums(self):
        """torch.Tensor: Buffer of shape (max_batch_size, num_attn_heads, kv_max_split_num)."""
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._exp_sums

    @property
    def max_logits(self):
        """torch.Tensor: Buffer of shape (max_batch_size, num_attn_heads, kv_max_split_num)."""
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._max_logits

    def initialize(
        self,
        max_batch_size: int,
        num_attn_heads: int,
        kv_max_split_num: int,
        head_dim: int,
        dtype: torch.dtype = torch.float32,
        device: torch.device = None,
    ) -> None:
        """Initialize tensors.

        Args:
            max_batch_size (int): The maximum batch size over all the model forward.
                This could be greater than the batch size in attention forward func when using dynamic batch size.
            num_attn_heads (int): Number of attention heads.
            kv_max_split_num (int): The maximum number of blocks split on kv in flash-decoding algorithm.
                **The maximum length/size of blocks split on kv should be the kv cache block size.**
            head_dim (int): Head dimension.
            dtype (torch.dtype, optional): Data type to be assigned to intermediate tensors.
                Defaults to ``torch.float32``.
            device (torch.device, optional): Device used to initialize intermediate tensors.
                Defaults to ``None``, in which case ``get_current_device()`` is queried when
                this method is called.

        Raises:
            AssertionError: If the tensors have already been initialized.
        """
        assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."

        # Resolve the default device at call time. A default written in the signature is
        # evaluated once when the class body runs (i.e. at import), which would pin the
        # buffers to whatever device happened to be current at that moment.
        if device is None:
            device = get_current_device()

        # The three per-split statistics buffers share one shape; only the output buffer
        # carries the extra trailing head dimension.
        split_shape = (max_batch_size, num_attn_heads, kv_max_split_num)

        self._mid_output = torch.empty(size=(*split_shape, head_dim), dtype=dtype, device=device)
        self._mid_output_lse = torch.empty(size=split_shape, dtype=dtype, device=device)
        self._exp_sums = torch.empty(size=split_shape, dtype=dtype, device=device)
        self._max_logits = torch.empty(size=split_shape, dtype=dtype, device=device)

        # Flip the flag last so a failed allocation leaves the holder uninitialized.
        self._tensors_initialized = True