[FAW] LFU initialize with dataset freq (#1513)

pull/1514/head
Jiarui Fang 2 years ago committed by GitHub
parent 1b8fee8e9c
commit 9feee6d06b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,7 +14,6 @@ class EvictionStrategy(Enum):
DATASET = 2
class CachedParamMgr(torch.nn.Module):
"""
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
@ -64,8 +63,7 @@ class CachedParamMgr(torch.nn.Module):
# cache_row_idx -> frequency, freq of the cache rows.
# classic lfu cache. evict the minimal freq value row in cuda cache.
self.register_buffer("freq_cnter",
torch.empty(self.cuda_row_num,
device=torch.cuda.current_device(),
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
dtype=torch.long).fill_(sys.maxsize),
persistent=False)
@ -163,8 +161,12 @@ class CachedParamMgr(torch.nn.Module):
reorder the weight according to ids' frequency in dataset before training.
Execute only once before training, also known as warmup phase.
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
:NOTE If you are use the LFU as the eviction strategy, you can skip this function.
Note:
If you would like to use the DATASET as the eviction strategy, you must call this function.
Note:
If you are use the LFU as the eviction strategy, you can skip this function. If you still use this function. It will initialize
The frequency in LFU cache using the dataset statistics.
Args:
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
@ -182,24 +184,31 @@ class CachedParamMgr(torch.nn.Module):
with Timer() as timer:
# extract rows from cpu weight
preload_row_ids = torch.arange(preload_row_num)
preload_slot_ids = preload_row_ids.cuda()
preload_cuda_row_idxs = preload_row_ids.cuda()
if self.buffer_size > 0:
self.limit_buff_index_copyer.index_copy(0,
src_index=preload_row_ids,
tgt_index=preload_slot_ids,
tgt_index=preload_cuda_row_idxs,
src=self.weight.view(self.num_embeddings, -1),
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
else:
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows)
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
preload_rows)
# update auxiliary info
slot_offsets = preload_slot_ids
self.cached_idx_map[preload_slot_ids] = preload_slot_ids
slot_offsets = preload_cuda_row_idxs
self.cached_idx_map[preload_cuda_row_idxs] = preload_cuda_row_idxs
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter.index_fill_(0,preload_slot_ids,0)
self.inverted_cached_idx[preload_slot_ids] = slot_offsets
# if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset.
if ids_freq_mapping is None:
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
else:
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, self.idx_map[preload_cuda_row_idxs])
self.inverted_cached_idx[preload_cuda_row_idxs] = slot_offsets
self._cuda_available_row_num -= preload_row_num
print(f'Cache warmup finished cost {timer.elapsed} sec.')

Loading…
Cancel
Save