@@ -64,6 +64,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
         cuda_row_num = int(num_embeddings * cache_ratio)
         # configure weight & cache
         self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
+        self.cache_op = True
 
     def _weight_alloc(self, dtype, device):
         weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
@@ -97,8 +98,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                                                evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
 
-    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
-        if cache_op:
+    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
+        if self.cache_op:
             with torch.no_grad():
                 input = self.cache_weight_mgr.prepare_ids(input)
 
@@ -119,6 +120,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         yield self.cache_weight_mgr.cuda_cached_weight
 
+    def set_cache_op(self, cache_op: bool = True):
+        self.cache_op = cache_op
+
 
 
 ############################# Perf Log ###################################
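A minimal usage sketch of the new API (not part of the diff): `cache_op` moves from a per-call `forward()` argument to instance state toggled via `set_cache_op()`. The import path and constructor arguments below are assumptions for illustration only, not taken from this change.

import torch
from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag  # import path is an assumption

# Hypothetical construction; cache-related kwargs (cache_ratio, ids_freq_mapping, ...) omitted.
bag = FreqAwareEmbeddingBag(num_embeddings=1000, embedding_dim=16)

ids = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])

# Before this change: out = bag(ids, offsets, cache_op=False)
# After this change: flip the flag once instead of threading it through every call.
bag.set_cache_op(False)    # forward() now skips cache_weight_mgr.prepare_ids()
out = bag(ids, offsets)

bag.set_cache_op(True)     # re-enable frequency-aware cache preparation
out = bag(ids, offsets)

One likely upside of storing the flag on the module: `forward()` keeps the same signature shape as `torch.nn.EmbeddingBag`, so wrappers that pass arguments through positionally are unaffected.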