|
|
|
@ -81,13 +81,16 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
|
|
|
|
for rank in self.rank_of_tables:
|
|
|
|
|
self.embedding_dim_per_rank[rank] += embedding_dim
|
|
|
|
|
|
|
|
|
|
def forward(self,
|
|
|
|
|
indices: torch.Tensor,
|
|
|
|
|
offsets: torch.Tensor = None,
|
|
|
|
|
per_sample_weights=None,
|
|
|
|
|
shape_hook=None,
|
|
|
|
|
already_split_along_rank=True,
|
|
|
|
|
cache_op=True):
|
|
|
|
|
self.cache_op = True
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|
indices: torch.Tensor,
|
|
|
|
|
offsets: torch.Tensor = None,
|
|
|
|
|
per_sample_weights=None,
|
|
|
|
|
shape_hook=None,
|
|
|
|
|
already_split_along_rank=True,
|
|
|
|
|
):
|
|
|
|
|
if not already_split_along_rank:
|
|
|
|
|
# not recommended; it takes time.
|
|
|
|
|
batch_size = (offsets.shape[0]) // self.global_tables_num
|
|
|
|
@ -97,7 +100,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
|
|
|
|
# recommended.
|
|
|
|
|
batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
|
|
|
|
|
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
|
|
|
|
|
if cache_op:
|
|
|
|
|
if self.cache_op:
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
indices = self.cache_weight_mgr.prepare_ids(local_indices)
|
|
|
|
|
local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
|
|
|
|
@ -185,6 +188,9 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
|
|
|
|
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
|
|
|
|
|
return local_indices, local_offsets, local_per_sample_weights
|
|
|
|
|
|
|
|
|
|
def set_cache_op(self, cache_op: bool = True):
    """Store the ``cache_op`` flag on this instance.

    The ``forward`` method checks ``self.cache_op`` before asking
    ``cache_weight_mgr`` to prepare ids, so this switch turns that
    step on or off for subsequent calls.

    Args:
        cache_op: New value for the flag. Defaults to ``True``.
    """
    # Stored as given; no coercion or validation is applied.
    self.cache_op = cache_op
|
|
|
|
|
|
|
|
|
|
def print_comm_stats_(self):
    """Ask the cache weight manager to print its communication stats.

    This is a thin pass-through to ``cache_weight_mgr.print_comm_stats()``;
    nothing is returned.
    """
    manager = self.cache_weight_mgr
    manager.print_comm_stats()
|
|
|
|
|
|
|
|
|
|