[embedding] cache_embedding small improvement (#1564)

pull/1574/head
CsRic 2022-09-08 16:41:19 +08:00 committed by GitHub
parent 10dd8226b1
commit a389ac4ec9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 13 deletions

View File

@ -178,7 +178,7 @@ class CachedParamMgr(torch.nn.Module):
"""reorder
reorder the weight according to ids' frequency in dataset before training.
Execute only once before training, also known as warmup phase.
Note:
If you would like to use the DATASET as the eviction strategy, you must call this function.
@ -304,7 +304,8 @@ class CachedParamMgr(torch.nn.Module):
self.evict_backlist = cpu_row_idxs
with record_function("(pre-id) get cpu row idxs"):
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(
cpu_row_idxs, self.cached_idx_map, assume_unique=True, invert=True)]
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
self.num_miss_history.append(len(comm_cpu_row_idxs))
@ -345,7 +346,7 @@ class CachedParamMgr(torch.nn.Module):
evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
if evict_num > 0:
with Timer() as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist, assume_unique=True)
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
if self._evict_strategy == EvictionStrategy.DATASET:
# mask method.

View File

@ -75,7 +75,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
with torch.no_grad():
reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx)
@ -124,6 +123,6 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
def print_comm_stats_(self):
self.cache_weight_mgr.print_comm_stats()
def element_size(self):
return self.weight.element_size()
return self.weight.element_size()

View File

@ -87,6 +87,7 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
local_per_sample_weights_list: List(torch.Tensor) = []
offset_pre_end = 0 # local_offsets trick
for i, handle_table in enumerate(self.assigned_table_list):
indices_start_position = offsets[batch_size * handle_table]
if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
@ -94,6 +95,28 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
indices_end_position = indices.shape[0]
else:
indices_end_position = offsets[batch_size * (handle_table + 1)]
# alternative approach: reduce malloc
'''
# 1. local_indices_list:
local_indices = indices.narrow(0, indices_start_position, indices_end_position - indices_start_position)
torch.sub(local_indices, self.idx_offset_list[i], out=local_indices)
local_indices_list.append(local_indices)
# 2. local_offsets_list:
if i + 1 == len(self.assigned_table_list):
# till-the-end special case
if not self.include_last_offset:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size + 1)
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
local_offsets_list.append(local_offsets)
else:
temp_holder = offsets[batch_size * handle_table].item()
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size)
torch.add(local_offsets, offset_pre_end - offsets[batch_size * handle_table], out=local_offsets)
offset_pre_end = offsets[batch_size * (handle_table + 1)] + offset_pre_end - temp_holder
local_offsets_list.append(local_offsets)
'''
# 1. local_indices_list:
local_indices_list.append(
indices.narrow(0, indices_start_position,
@ -103,21 +126,20 @@ class ParallelFreqAwareEmbeddingBagTablewise(FreqAwareEmbeddingBag):
# till-the-end special case
if not self.include_last_offset:
local_offsets = offsets.narrow(0, batch_size * handle_table,
batch_size).add(offset_pre_end - offsets[batch_size *
(handle_table)])
batch_size).add(offset_pre_end - offsets[batch_size
* (handle_table)])
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
1).add(offset_pre_end - offsets[batch_size * (handle_table)])
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
+ 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
local_offsets_list.append(local_offsets)
else:
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size +
1).add(offset_pre_end - offsets[batch_size * (handle_table)])
local_offsets = offsets.narrow(0, batch_size * handle_table, batch_size
+ 1).add(offset_pre_end - offsets[batch_size * (handle_table)])
offset_pre_end = local_offsets[-1]
local_offsets_list.append(local_offsets[:-1])
# 3. local_per_sample_weights_list:
if per_sample_weights != None:
local_per_sample_weights_list.append(per_sample_weights[indices_start_position:indices_end_position])
local_indices = torch.cat(local_indices_list, 0)
local_offsets = torch.cat(local_offsets_list, 0)
local_per_sample_weights = None