[embeddings] cache option (#1635)

Jiarui Fang 2022-09-23 16:40:18 +08:00 committed by GitHub
parent a088022efc
commit e57df80325
3 changed files with 25 additions and 14 deletions


@@ -97,12 +97,13 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                                               evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
 
-    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(input)
+    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
+        if cache_op:
+            with torch.no_grad():
+                input = self.cache_weight_mgr.prepare_ids(input)
 
-        embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
-                                     self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+        embeddings = F.embedding_bag(input.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                     self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                      per_sample_weights, self.include_last_offset, self.padding_idx)
         if shape_hook is not None:
             embeddings = shape_hook(embeddings)
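For orientation, prepare_ids is the cache manager's bookkeeping step: it translates raw embedding row ids into row indices of the GPU-resident cuda_cached_weight, admitting missing rows along the way. The toy below is my own stand-in for that behavior (not ColossalAI's actual manager class, and kept on CPU so it runs anywhere); it only illustrates why the remap becomes skippable.

import torch

# Toy stand-in for the diff's cache_weight_mgr (an illustration, not the real
# ColossalAI class): prepare_ids maps raw row ids to rows of a small cache,
# copying missing rows in on first use. No eviction, CPU-only for simplicity.
class ToyCacheMgr:

    def __init__(self, weight: torch.Tensor, cache_rows: int):
        self.weight = weight                            # full embedding table
        self.cuda_cached_weight = torch.zeros(cache_rows, weight.shape[1])
        self.slot_of = {}                               # raw id -> cache row

    def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(ids)
        for i, raw in enumerate(ids.tolist()):
            if raw not in self.slot_of:                 # cache miss: admit the row
                self.slot_of[raw] = len(self.slot_of)
                self.cuda_cached_weight[self.slot_of[raw]] = self.weight[raw]
            out[i] = self.slot_of[raw]
        return out

With that picture, the new flag reads naturally: cache_op=True runs prepare_ids inside forward as before, while cache_op=False promises that the caller already passed cache-local ids, so forward goes straight to F.embedding_bag.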


@@ -72,11 +72,19 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                           compute_attr=ComputePattern.TP1D)
         return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
 
-    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
-        output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
-                                       self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+    def forward(self,
+                indices,
+                offsets=None,
+                per_sample_weights=None,
+                shape_hook=None,
+                scatter_dim=0,
+                gather_dim=-1,
+                cache_op: bool = True):
+        if cache_op:
+            with torch.no_grad():
+                indices = self.cache_weight_mgr.prepare_ids(indices)
+        output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        per_sample_weights, self.include_last_offset, self.padding_idx)
         if shape_hook is not None:
             output_shard = shape_hook(output_shard)
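A likely way to use the new flag (my assumption about caller code, not part of this commit) is to run the cache operation once, e.g. in a prefetching data pipeline, and then call forward with cache_op=False so the ids are not remapped a second time. Continuing the ToyCacheMgr sketch from above:

import torch
import torch.nn.functional as F

# Minimal forward mirroring the gating added in this diff, built on the
# ToyCacheMgr defined earlier (all names here are illustrative).
def toy_forward(mgr, input, offsets=None, cache_op=True):
    if cache_op:
        with torch.no_grad():               # cache bookkeeping needs no autograd
            input = mgr.prepare_ids(input)
    return F.embedding_bag(input, mgr.cuda_cached_weight, offsets)

mgr = ToyCacheMgr(torch.randn(100, 8), cache_rows=16)
ids = torch.tensor([42, 7, 42, 99])
offsets = torch.tensor([0, 2])

out1 = toy_forward(mgr, ids, offsets)                      # remaps internally
ready = mgr.prepare_ids(ids)                               # done ahead of time instead
out2 = toy_forward(mgr, ready, offsets, cache_op=False)    # skip the duplicate remap
assert torch.allclose(out1, out2)

Either path feeds identical cache-local ids to F.embedding_bag; the flag only controls where the remap happens.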


@@ -86,7 +86,8 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
                 offsets: torch.Tensor = None,
                 per_sample_weights=None,
                 shape_hook=None,
-                already_split_along_rank=True):
+                already_split_along_rank=True,
+                cache_op=True):
         if not already_split_along_rank:
             # not recommanded. it takes time.
             batch_size = (offsets.shape[0]) // self.global_tables_num
@@ -96,9 +97,10 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             # recommanded.
             batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
             local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
-        with torch.no_grad():
-            reorder_ids = self.cache_weight_mgr.prepare_ids(local_indices)
-        local_output = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
+        if cache_op:
+            with torch.no_grad():
+                indices = self.cache_weight_mgr.prepare_ids(local_indices)
+        local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
                                        self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
                                        local_per_sample_weights, self.include_last_offset, self.padding_idx)
         local_output = torch.cat(local_output.split(batch_size), 1)
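One step of the tablewise path worth unpacking is the unchanged last line, local_output = torch.cat(local_output.split(batch_size), 1): with several local tables served by a single embedding_bag call, the result comes back with the tables' batches stacked along dim 0, and the split/cat regroups it into one row per sample. A tiny shape check (toy numbers, my own):

import torch

B, T, dim = 2, 3, 4                       # batch size, local tables, embedding dim
stacked = torch.arange(T * B * dim, dtype=torch.float32).reshape(T * B, dim)
regrouped = torch.cat(stacked.split(B), 1)
print(regrouped.shape)                    # torch.Size([2, 12]) == (B, T * dim)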