|
|
|
@ -27,10 +27,10 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|
|
|
|
include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format.. Defaults to False.
|
|
|
|
|
dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32.
|
|
|
|
|
device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu.
|
|
|
|
|
cuda_row_num (int, optional): the max number of embedding vector in cuda cache. Defaults to 0.
|
|
|
|
|
cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row
|
|
|
|
|
ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occures in dataset. Defaults to None.
|
|
|
|
|
warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7.
|
|
|
|
|
buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, means do not use the buffer. Defaults to 0.
|
|
|
|
|
buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0.
|
|
|
|
|
pin_weight (bool, optional): pin the cpu weight. Defaults to False.
|
|
|
|
|
evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET.
|
|
|
|
|
"""
|
|
|
|
@ -48,7 +48,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|
|
|
|
include_last_offset: bool = False,
|
|
|
|
|
dtype: Optional[torch.dtype] = None,
|
|
|
|
|
device: Optional[torch.device] = None,
|
|
|
|
|
cuda_row_num: int = 0,
|
|
|
|
|
cache_ratio: float = 0.01,
|
|
|
|
|
ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None,
|
|
|
|
|
warmup_ratio: float = 0.7,
|
|
|
|
|
buffer_size: int = 0,
|
|
|
|
@ -57,10 +57,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|
|
|
|
super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
|
|
|
|
|
scale_grad_by_freq, sparse, mode, include_last_offset)
|
|
|
|
|
|
|
|
|
|
assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0"
|
|
|
|
|
self.evict_strategy = evict_strategy
|
|
|
|
|
if _weight is None:
|
|
|
|
|
_weight = self._weight_alloc(dtype, device)
|
|
|
|
|
|
|
|
|
|
cuda_row_num = int(num_embeddings * cache_ratio)
|
|
|
|
|
# configure weight & cache
|
|
|
|
|
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
|
|
|
|
|
|
|
|
|
|