import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils import get_current_device


class FDIntermTensors(metaclass=SingletonMeta):
    """Singleton class to hold tensors used for storing intermediate values in flash-decoding.
    For now, it holds the intermediate output and logsumexp (which will be used in the reduction step along kv).
    """

    def __init__(self):
        self._tensors_initialized = False

    def _reset(self):
        self._tensors_initialized = False
        del self._mid_output
        del self._mid_output_lse
        del self._exp_sums
        del self._max_logits

    @property
    def is_initialized(self):
        return self._tensors_initialized

    @property
    def mid_output(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output

    @property
    def mid_output_lse(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._mid_output_lse

    @property
    def exp_sums(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._exp_sums

    @property
    def max_logits(self):
        assert self.is_initialized, "Intermediate tensors not initialized yet"
        return self._max_logits

    def initialize(
        self,
        max_batch_size: int,
        num_attn_heads: int,
        kv_max_split_num: int,
        head_dim: int,
        dtype: torch.dtype = torch.float32,
        device: torch.device = get_current_device(),
    ) -> None:
        """Initialize tensors.

        Args:
            max_batch_size (int): The maximum batch size across all model forward passes.
                This can be greater than the batch size used in the attention forward function
                when dynamic batching is enabled.
            num_attn_heads (int): Number of attention heads.
            kv_max_split_num (int): The maximum number of blocks split along kv in the flash-decoding algorithm.
                **The maximum length/size of each block split along kv should be the kv cache block size.**
            head_dim (int): Head dimension.
            dtype (torch.dtype, optional): Data type to be assigned to the intermediate tensors.
            device (torch.device, optional): Device used to initialize the intermediate tensors.
        """
        assert not self.is_initialized, "Intermediate tensors used for Flash-Decoding have been initialized."

        # Partial output of each kv split: (max_batch_size, num_attn_heads, kv_max_split_num, head_dim)
        self._mid_output = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num, head_dim), dtype=dtype, device=device
        )
        # Logsumexp of each kv split, used in the reduction step: (max_batch_size, num_attn_heads, kv_max_split_num)
        self._mid_output_lse = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )
        # Per-split exponent sums and max logits, with the same shape as the logsumexp buffer
        self._exp_sums = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )
        self._max_logits = torch.empty(
            size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
        )

        self._tensors_initialized = True
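

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# example of how the singleton might be set up before flash-decoding kernels
# run. The concrete sizes below are assumptions chosen for illustration; in
# practice they come from the model config and the kv cache layout.
if __name__ == "__main__":
    fd_interm = FDIntermTensors()
    if not fd_interm.is_initialized:
        fd_interm.initialize(
            max_batch_size=8,     # assumed maximum batch size
            num_attn_heads=32,    # assumed number of attention heads
            kv_max_split_num=16,  # assumed maximum number of kv splits
            head_dim=128,         # assumed head dimension
            dtype=torch.float32,
        )

    # SingletonMeta is expected to return the same instance on every call,
    # so the buffers are shared by every user of the class.
    assert FDIntermTensors() is fd_interm

    print(fd_interm.mid_output.shape)      # torch.Size([8, 32, 16, 128])
    print(fd_interm.mid_output_lse.shape)  # torch.Size([8, 32, 16])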