[embeddings] cache option (#1635)

pull/1660/head
Jiarui Fang 2022-09-23 16:40:18 +08:00 committed by GitHub
parent a088022efc
commit e57df80325
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 14 deletions

View File

@ -97,12 +97,13 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
evict_strategy=self.evict_strategy)
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
with torch.no_grad():
reorder_ids = self.cache_weight_mgr.prepare_ids(input)
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
if cache_op:
with torch.no_grad():
input = self.cache_weight_mgr.prepare_ids(input)
embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx)
if shape_hook is not None:
embeddings = shape_hook(embeddings)

View File

@ -72,11 +72,19 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
compute_attr=ComputePattern.TP1D)
return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
with torch.no_grad():
reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
def forward(self,
indices,
offsets=None,
per_sample_weights=None,
shape_hook=None,
scatter_dim=0,
gather_dim=-1,
cache_op: bool = True):
if cache_op:
with torch.no_grad():
indices = self.cache_weight_mgr.prepare_ids(indices)
output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx)
if shape_hook is not None:
output_shard = shape_hook(output_shard)

View File

@ -86,7 +86,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
offsets: torch.Tensor = None,
per_sample_weights=None,
shape_hook=None,
already_split_along_rank=True):
already_split_along_rank=True,
cache_op=True):
if not already_split_along_rank:
# not recommanded. it takes time.
batch_size = (offsets.shape[0]) // self.global_tables_num
@ -96,9 +97,10 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
# recommanded.
batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
with torch.no_grad():
reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
if cache_op:
with torch.no_grad():
indices = self.cache_weight_mgr.prepare_ids(local_indices)
local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
local_per_sample_weights, self.include_last_offset, self.padding_idx)
local_output = torch.cat(local_output.split(batch_size), 1)