[embedding] isolate cache_op from forward (#1645)

Co-authored-by: ric <mkkt_bkkt@mail.ustc.edu.cn>
CsRic 2022-09-26 11:18:59 +08:00 committed by GitHub
parent c5d39215f6
commit 0767f67a0f
3 changed files with 34 additions and 19 deletions

@@ -64,6 +64,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
         cuda_row_num = int(num_embeddings * cache_ratio)
         # configure weight & cache
         self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
+        self.cache_op = True

     def _weight_alloc(self, dtype, device):
         weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
@@ -97,8 +98,8 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                                                evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)

-    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None, cache_op=True):
-        if cache_op:
+    def forward(self, input, offsets=None, per_sample_weights=None, shape_hook=None):
+        if self.cache_op:
             with torch.no_grad():
                 input = self.cache_weight_mgr.prepare_ids(input)
@@ -119,6 +120,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         yield self.cache_weight_mgr.cuda_cached_weight

+    def set_cache_op(self, cache_op: bool = True):
+        self.cache_op = cache_op
+
     ############################# Perf Log ###################################
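With this change the cache-management step is toggled once on the module via `set_cache_op()` instead of being passed as a `cache_op` argument on every `forward()` call. A minimal usage sketch follows; the import path, the adequacy of the constructor defaults, and the availability of a CUDA device are assumptions on my part, not taken from this diff:

```python
import torch

# Assumed import path; adjust to wherever FreqAwareEmbeddingBag lives in your checkout.
from colossalai.nn.parallel.layers.cache_embedding import FreqAwareEmbeddingBag

# Assumes default cache settings are acceptable and a GPU is available
# (the cached weight lives on the CUDA device).
bag = FreqAwareEmbeddingBag(num_embeddings=1000, embedding_dim=16)

indices = torch.randint(0, 1000, (64,))
offsets = torch.arange(0, 64, 4)

# Default behaviour (self.cache_op == True): forward() remaps ids through
# cache_weight_mgr.prepare_ids() under torch.no_grad() before the lookup.
out = bag(indices, offsets)

# Before this commit the caller had to write bag(indices, offsets, cache_op=False)
# on every call. Now the flag lives on the module; disable it only when the ids
# you pass in have already been mapped to cache rows elsewhere.
bag.set_cache_op(False)
out = bag(indices, offsets)
bag.set_cache_op(True)
```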


@@ -60,6 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cache_ratio, ids_freq_mapping,
                              warmup_ratio, buffer_size, pin_weight, evict_strategy)
+        self.cache_op = True

     def _weight_alloc(self, dtype, device):
         weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
@@ -72,15 +73,16 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                            compute_attr=ComputePattern.TP1D)
         return ColoTensor.from_torch_tensor(weight, spec=colo_tensor_spec)

-    def forward(self,
-                indices,
-                offsets=None,
-                per_sample_weights=None,
-                shape_hook=None,
-                scatter_dim=0,
-                gather_dim=-1,
-                cache_op: bool = True):
-        if cache_op:
+    def forward(
+        self,
+        indices,
+        offsets=None,
+        per_sample_weights=None,
+        shape_hook=None,
+        scatter_dim=0,
+        gather_dim=-1,
+    ):
+        if self.cache_op:
             with torch.no_grad():
                 indices = self.cache_weight_mgr.prepare_ids(indices)
         output_shard = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
@@ -94,6 +96,9 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                                        gather_dim=gather_dim)
         return output_full

+    def set_cache_op(self, cache_op: bool = True):
+        self.cache_op = cache_op
+
     @classmethod
     def from_pretrained(
         cls,


@@ -81,13 +81,16 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
         for rank in self.rank_of_tables:
             self.embedding_dim_per_rank[rank] += embedding_dim
+        self.cache_op = True

-    def forward(self,
-                indices: torch.Tensor,
-                offsets: torch.Tensor = None,
-                per_sample_weights=None,
-                shape_hook=None,
-                already_split_along_rank=True,
-                cache_op=True):
+    def forward(
+        self,
+        indices: torch.Tensor,
+        offsets: torch.Tensor = None,
+        per_sample_weights=None,
+        shape_hook=None,
+        already_split_along_rank=True,
+    ):
         if not already_split_along_rank:
             # not recommanded. it takes time.
             batch_size = (offsets.shape[0]) // self.global_tables_num
@@ -97,7 +100,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             # recommanded.
             batch_size = (offsets.shape[0]) // len(self.assigned_table_list)
             local_indices, local_offsets, local_per_sample_weights = indices, offsets, per_sample_weights
-        if cache_op:
+        if self.cache_op:
             with torch.no_grad():
                 indices = self.cache_weight_mgr.prepare_ids(local_indices)
         local_output = F.embedding_bag(indices.cuda(), self.cache_weight_mgr.cuda_cached_weight, local_offsets,
@@ -185,6 +188,9 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
             local_per_sample_weights = torch.cat(local_per_sample_weights_list, 0)
         return local_indices, local_offsets, local_per_sample_weights

+    def set_cache_op(self, cache_op: bool = True):
+        self.cache_op = cache_op
+
     def print_comm_stats_(self):
         self.cache_weight_mgr.print_comm_stats()
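All three classes now expose the same `set_cache_op()` hook, so calling code can flip the flag for a whole model in one place, for example when id preparation is performed once outside the modules. A small illustrative helper; the function name and the duck-typed `hasattr` check are my own, not part of this commit:

```python
import torch.nn as nn

def set_cache_op_recursively(model: nn.Module, cache_op: bool = True) -> None:
    """Toggle the in-forward id-preparation step on every submodule that
    exposes set_cache_op() (FreqAwareEmbeddingBag and its parallel/tablewise
    variants after this commit)."""
    for module in model.modules():
        if hasattr(module, "set_cache_op"):
            module.set_cache_op(cache_op)

# Example: disable the in-forward cache op model-wide, run lookups with
# pre-mapped ids, then restore the default behaviour.
# set_cache_op_recursively(model, cache_op=False)
# ... forward passes ...
# set_cache_op_recursively(model, cache_op=True)
```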