|
|
|
@ -27,10 +27,10 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|
|
|
|
include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format. Defaults to False. |
|
|
|
|
dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32. |
|
|
|
|
device (torch.device, optional): device of the cpu weight. Defaults to None meaning cpu. |
|
|
|
|
cuda_row_num (int, optional): the max number of embedding vectors in the cuda cache. Defaults to 0. |
|
|
|
|
cache_ratio (float, optional): ratio of #cuda_weight_row to #cpu_weight_row; must not exceed 1.0. Defaults to 0.01. |
|
|
|
|
ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency with which each embedding vector occurs in the dataset. Defaults to None. |
|
|
|
|
warmup_ratio (float, optional): the ratio of the cuda cache that is warmed up. Defaults to 0.7. |
|
|
|
|
buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, means do not use the buffer. Defaults to 0. |
|
|
|
|
buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0. |
|
|
|
|
pin_weight (bool, optional): pin the cpu weight. Defaults to False. |
|
|
|
|
evict_strategy (EvictionStrategy, optional): evict strategy of the software cache. Defaults to EvictionStrategy.DATASET. |
|
|
|
|
""" |
|
|
|
@ -48,7 +48,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|
|
|
|
include_last_offset: bool = False, |
|
|
|
|
dtype: Optional[torch.dtype] = None, |
|
|
|
|
device: Optional[torch.device] = None, |
|
|
|
|
cuda_row_num: int = 0, |
|
|
|
|
cache_ratio: float = 0.01, |
|
|
|
|
ids_freq_mapping: Optional[Union[List, torch.Tensor]] = None, |
|
|
|
|
warmup_ratio: float = 0.7, |
|
|
|
|
buffer_size: int = 0, |
|
|
|
@ -57,10 +57,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|
|
|
|
super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, |
|
|
|
|
scale_grad_by_freq, sparse, mode, include_last_offset) |
|
|
|
|
|
|
|
|
|
assert cache_ratio <= 1.0, f"cache ratio {cache_ratio} must less than 1.0" |
|
|
|
|
self.evict_strategy = evict_strategy |
|
|
|
|
if _weight is None: |
|
|
|
|
_weight = self._weight_alloc(dtype, device) |
|
|
|
|
|
|
|
|
|
cuda_row_num = int(num_embeddings * cache_ratio) |
|
|
|
|
# configure weight & cache |
|
|
|
|
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight) |
|
|
|
|
|
|
|
|
|