[FAW] LFU initialize with dataset freq (#1513)

pull/1514/head
Jiarui Fang 2022-08-29 12:52:53 +08:00 committed by GitHub
parent 1b8fee8e9c
commit 9feee6d06b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 32 additions and 23 deletions

View File

@ -14,7 +14,6 @@ class EvictionStrategy(Enum):
DATASET = 2 DATASET = 2
class CachedParamMgr(torch.nn.Module): class CachedParamMgr(torch.nn.Module):
""" """
Manage Embedding Weights on CPU and CUDA memory uses a software cache. Manage Embedding Weights on CPU and CUDA memory uses a software cache.
@ -64,8 +63,7 @@ class CachedParamMgr(torch.nn.Module):
# cache_row_idx -> frequency, freq of the cache rows. # cache_row_idx -> frequency, freq of the cache rows.
# classic lfu cache. evict the minimal freq value row in cuda cache. # classic lfu cache. evict the minimal freq value row in cuda cache.
self.register_buffer("freq_cnter", self.register_buffer("freq_cnter",
torch.empty(self.cuda_row_num, torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
device=torch.cuda.current_device(),
dtype=torch.long).fill_(sys.maxsize), dtype=torch.long).fill_(sys.maxsize),
persistent=False) persistent=False)
@ -163,8 +161,12 @@ class CachedParamMgr(torch.nn.Module):
reorder the weight according to ids' frequency in dataset before training. reorder the weight according to ids' frequency in dataset before training.
Execute only once before training, also known as warmup phase. Execute only once before training, also known as warmup phase.
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function. Note:
:NOTE If you are use the LFU as the eviction strategy, you can skip this function. If you would like to use the DATASET as the eviction strategy, you must call this function.
Note:
If you are use the LFU as the eviction strategy, you can skip this function. If you still use this function. It will initialize
The frequency in LFU cache using the dataset statistics.
Args: Args:
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight. ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
@ -182,24 +184,31 @@ class CachedParamMgr(torch.nn.Module):
with Timer() as timer: with Timer() as timer:
# extract rows from cpu weight # extract rows from cpu weight
preload_row_ids = torch.arange(preload_row_num) preload_row_ids = torch.arange(preload_row_num)
preload_slot_ids = preload_row_ids.cuda() preload_cuda_row_idxs = preload_row_ids.cuda()
if self.buffer_size > 0: if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0, self.limit_buff_index_copyer.index_copy(0,
src_index=preload_row_ids, src_index=preload_row_ids,
tgt_index=preload_slot_ids, tgt_index=preload_cuda_row_idxs,
src=self.weight.view(self.num_embeddings, -1), src=self.weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
else: else:
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda() preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows) self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
preload_rows)
# update auxiliary info # update auxiliary info
slot_offsets = preload_slot_ids slot_offsets = preload_cuda_row_idxs
self.cached_idx_map[preload_slot_ids] = preload_slot_ids self.cached_idx_map[preload_cuda_row_idxs] = preload_cuda_row_idxs
if self._evict_strategy == EvictionStrategy.LFU: if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter.index_fill_(0,preload_slot_ids,0) # if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset.
self.inverted_cached_idx[preload_slot_ids] = slot_offsets if ids_freq_mapping is None:
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
else:
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, self.idx_map[preload_cuda_row_idxs])
self.inverted_cached_idx[preload_cuda_row_idxs] = slot_offsets
self._cuda_available_row_num -= preload_row_num self._cuda_available_row_num -= preload_row_num
print(f'Cache warmup finished cost {timer.elapsed} sec.') print(f'Cache warmup finished cost {timer.elapsed} sec.')