import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils import get_current_device


class FDIntermTensors(metaclass=SingletonMeta):
    """Singleton class holding tensors used to store intermediate values in flash-decoding.

    For now, it holds the intermediate output and logsumexp (used in the reduction step along kv),
    together with the per-split exponential sums and max logits.
    """

    def __init__(self):
        self._tensors_initialized = False

    def _reset(self):
        self._tensors_initialized = False
        del self._mid_output
        del self._mid_output_lse
        del self._exp_sums
        del self._max_logits

    @property
    def is_initialized(self):
        return self._tensors_initialized

    @property
    def mid_output(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output

    @property
    def mid_output_lse(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output_lse

    @property
    def exp_sums(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._exp_sums

    @property
    def max_logits(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._max_logits

    def initialize(
        self,
        max_batch_size: int,
        num_attn_heads: int,
        kv_max_split_num: int,
        head_dim: int,
        dtype: torch.dtype = torch.float32,
        device: torch.device = get_current_device(),
    ) -> None:
        """Initialize the intermediate tensors.

        Args:
            max_batch_size (int): The maximum batch size over all model forward passes. This can be greater than
                the batch size seen in the attention forward func when a dynamic batch size is used.
            num_attn_heads (int): Number of attention heads.
            kv_max_split_num (int): The maximum number of blocks the kv sequence is split into in the
                flash-decoding algorithm. The maximum length/size of each kv split block should be the
                kv cache block size.
            head_dim (int): Head dimension.
            dtype (torch.dtype, optional): Data type assigned to the intermediate tensors.
            device (torch.device, optional): Device used to initialize the intermediate tensors.
        """
        assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."

        # Partial attention output for each kv split: (max_batch_size, num_attn_heads, kv_max_split_num, head_dim)
        self._mid_output = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
        )
        # Log-sum-exp of the attention scores for each kv split, consumed by the reduction step
        self._mid_output_lse = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )
        # Per-split softmax statistics (exponential sums and max logits) used when combining splits
        self._exp_sums = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )
        self._max_logits = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )
        self._tensors_initialized = True
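

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of how this singleton is typically driven: `initialize` is called
# once with the engine's runtime limits, after which a flash-decoding kernel and its
# reduction step read the pre-allocated buffers through the properties. The concrete
# sizes below are assumptions chosen for illustration, not values from the library.
if __name__ == "__main__":
    fd_interm = FDIntermTensors()
    if not fd_interm.is_initialized:
        fd_interm.initialize(
            max_batch_size=8,
            num_attn_heads=32,
            kv_max_split_num=16,  # e.g. ceil(max_seq_len / kv_cache_block_size)
            head_dim=128,
            dtype=torch.float32,
        )

    # A flash-decoding kernel would write per-split partial outputs and their
    # logsumexp values into these buffers, then reduce along the split dimension.
    print(fd_interm.mid_output.shape)      # (8, 32, 16, 128)
    print(fd_interm.mid_output_lse.shape)  # (8, 32, 16)

    # SingletonMeta makes later constructions return the same initialized instance.
    assert FDIntermTensors() is fd_interm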