@@ -86,7 +86,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                 offsets: torch.Tensor = None,
                 per_sample_weights=None,
                 shape_hook=None,
-                already_split_along_rank=True):
+                already_split_along_rank=True,
+                cache_op=True):
         if not already_split_along_rank:
             # not recommended. It takes time.
             batch_size = (offsets.shape[0]) // self.global_tables_num
@@ -96,9 +97,10 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             # recommended.
             batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
             local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
-        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
+        if cache_op:
+            with torch.no_grad():
+                indices = self.cache_weight_mgr.prepare_ids(local_indices)
+        local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
                                        self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        local_per_sample_weights, self.include_last_offset, self.padding_idx)
         local_output = torch.cat(local_output.split(batch_size), 1)
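For context, here is a minimal stand-alone sketch of the control flow this change introduces, not the class itself. prepare_ids below is a hypothetical placeholder for cache_weight_mgr.prepare_ids, which in the real class is expected to move the needed rows into cuda_cached_weight and return remapped ids. With cache_op=True the preparation runs under torch.no_grad() before the lookup; with cache_op=False it is skipped, so the ids passed in must already refer to rows of the cached weight.

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)    # stands in for cache_weight_mgr.cuda_cached_weight

def prepare_ids(ids):
    # Hypothetical stand-in for cache_weight_mgr.prepare_ids: identity here,
    # whereas the real manager remaps ids into the cached weight's rows.
    return ids

def lookup(indices, offsets, cache_op=True):
    # Mirrors the new branch: the cache preparation only runs when cache_op
    # is True, and always inside torch.no_grad().
    if cache_op:
        with torch.no_grad():
            indices = prepare_ids(indices)
    return F.embedding_bag(indices, weight, offsets)

out = lookup(torch.tensor([1, 5, 9, 2]), torch.tensor([0, 2]))
print(out.shape)    # torch.Size([2, 4])

Exposing cache_op appears to let a caller that has already performed the cache preparation (for example, once for several lookups over the same ids) avoid repeating it on every call; that reading is an inference from the diff, not stated in it.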