|
|
|
@ -81,13 +81,16 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
|
|
|
|
for rank in self.rank_of_tables: |
|
|
|
|
self.embedding_dim_per_rank[rank] += embedding_dim |
|
|
|
|
|
|
|
|
|
def forward(self, |
|
|
|
|
indices: torch.Tensor, |
|
|
|
|
offsets: torch.Tensor = None, |
|
|
|
|
per_sample_weights=None, |
|
|
|
|
shape_hook=None, |
|
|
|
|
already_split_along_rank=True, |
|
|
|
|
cache_op=True): |
|
|
|
|
self.cache_op = True |
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
|
self, |
|
|
|
|
indices: torch.Tensor, |
|
|
|
|
offsets: torch.Tensor = None, |
|
|
|
|
per_sample_weights=None, |
|
|
|
|
shape_hook=None, |
|
|
|
|
already_split_along_rank=True, |
|
|
|
|
): |
|
|
|
|
if not already_split_along_rank: |
|
|
|
|
# not recommended. it takes time. |
|
|
|
|
batch_size = (offsets.shape[0]) // self.global_tables_num |
|
|
|
@ -97,7 +100,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
|
|
|
|
# recommended. |
|
|
|
|
batch_size = (offsets.shape[0]) // len(self.assigned_table_list) |
|
|
|
|
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights |
|
|
|
|
if cache_op: |
|
|
|
|
if self.cache_op: |
|
|
|
|
with torch.no_grad(): |
|
|
|
|
indices = self.cache_weight_mgr.prepare_ids(local_indices) |
|
|
|
|
local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets, |
|
|
|
@ -185,6 +188,9 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
|
|
|
|
|
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0) |
|
|
|
|
return local_indices, local_offsets, local_per_sample_weights |
|
|
|
|
|
|
|
|
|
def set_cache_op(self, cache_op: bool = True) -> None:
    """Enable or disable the cache operation flag on this module.

    The value is stored on ``self.cache_op``. ``forward`` consults that
    attribute before asking the cache weight manager to prepare ids, so
    passing ``False`` here makes subsequent forward calls skip that step.

    Args:
        cache_op: New value for ``self.cache_op``. Defaults to ``True``.
    """
    self.cache_op = cache_op
|
|
|
|
|
|
|
|
|
def print_comm_stats_(self) -> None:
    """Print communication statistics via the cache weight manager.

    This is a thin pass-through to
    ``self.cache_weight_mgr.print_comm_stats()``; the manager's return
    value is not forwarded to the caller.
    """
    self.cache_weight_mgr.print_comm_stats()
|
|
|
|
|
|
|
|
|