[embedding] isolate cache_op from forward (#1645)

Co-authored-by: ric <mkkt_bkkt@mail.ustc.edu.cn>
pull/1648/head
CsRic 2022-09-26 11:18:59 +08:00 committed by GitHub
parent c5d39215f6
commit 0767f67a0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 19 deletions

View File

@ -64,6 +64,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
cuda_row_num = int(num_embeddings * cache_ratio)
# configure weight & cache
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
self.cache_op = True
def _weight_alloc(self, dtype, device):
weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
@ -97,8 +98,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
evict_strategy=self.evict_strategy)
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
if cache_op:
def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
if self.cache_op:
with torch.no_grad():
input = self.cache_weight_mgr.prepare_ids(input)
@ -119,6 +120,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
yield self.cache_weight_mgr.cuda_cached_weight
def set_cache_op(self, cache_op: bool = True):
self.cache_op = cache_op
############################# Perf Log ###################################

View File

@ -60,6 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
warmup_ratio, buffer_size, pin_weight, evict_strategy)
self.cache_op = True
def _weight_alloc(self, dtype, device):
weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
@ -72,15 +73,16 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
compute_attr=ComputePattern.TP1D)
return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)
def forward(self,
indices,
offsets=None,
per_sample_weights=None,
shape_hook=None,
scatter_dim=0,
gather_dim=-1,
cache_op: bool = True):
if cache_op:
def forward(
self,
indices,
offsets=None,
per_sample_weights=None,
shape_hook=None,
scatter_dim=0,
gather_dim=-1,
):
if self.cache_op:
with torch.no_grad():
indices = self.cache_weight_mgr.prepare_ids(indices)
output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
@ -94,6 +96,9 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
gather_dim=gather_dim)
return output_full
def set_cache_op(self, cache_op: bool = True):
self.cache_op = cache_op
@classmethod
def from_pretrained(
cls,

View File

@ -81,13 +81,16 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
for rank in self.rank_of_tables:
self.embedding_dim_per_rank[rank] += embedding_dim
def forward(self,
indices: torch.Tensor,
offsets: torch.Tensor = None,
per_sample_weights=None,
shape_hook=None,
already_split_along_rank=True,
cache_op=True):
self.cache_op = True
def forward(
self,
indices: torch.Tensor,
offsets: torch.Tensor = None,
per_sample_weights=None,
shape_hook=None,
already_split_along_rank=True,
):
if not already_split_along_rank:
# not recommanded. it takes time.
batch_size = (offsets.shape[0]) // self.global_tables_num
@ -97,7 +100,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
# recommanded.
batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
if cache_op:
if self.cache_op:
with torch.no_grad():
indices = self.cache_weight_mgr.prepare_ids(local_indices)
local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
@ -185,6 +188,9 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
return local_indices, local_offsets, local_per_sample_weights
def set_cache_op(self, cache_op: bool = True):
self.cache_op = cache_op
def print_comm_stats_(self):
self.cache_weight_mgr.print_comm_stats()