[FAW] LFU initialize with dataset freq (#1513)

pull/1514/head
Jiarui Fang 2022-08-29 12:52:53 +08:00 committed by GitHub
parent 1b8fee8e9c
commit 9feee6d06b
1 changed file with 32 additions and 23 deletions


@@ -14,7 +14,6 @@ class EvictionStrategy(Enum):
     DATASET = 2

-
 class CachedParamMgr(torch.nn.Module):
     """
     Manage Embedding Weights on CPU and CUDA memory uses a software cache.
@@ -64,8 +63,7 @@ class CachedParamMgr(torch.nn.Module):
             # cache_row_idx -> frequency, freq of the cache rows.
             # classic lfu cache. evict the minimal freq value row in cuda cache.
             self.register_buffer("freq_cnter",
-                                 torch.empty(self.cuda_row_num,
-                                             device=torch.cuda.current_device(),
+                                 torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
                                              dtype=torch.long).fill_(sys.maxsize),
                                  persistent=False)
@@ -82,14 +80,14 @@ class CachedParamMgr(torch.nn.Module):
         """
         if self._evict_strategy == EvictionStrategy.LFU:
             # find the minimal evict_num freq entries in cached_idx_map
-            _,evict_gpu_row_idxs = torch.topk(self.freq_cnter,evict_num,largest=False)
+            _, evict_gpu_row_idxs = torch.topk(self.freq_cnter, evict_num, largest=False)
             return evict_gpu_row_idxs
         elif self._evict_strategy == EvictionStrategy.DATASET:
             # cached_idx_map itself implies the priority of eviction.
             # The value of self.cached_idx_map represents cpu_row_idx.
             # The larger it is, the less frequently it will appear in the dataset,
             # and the higher its eviction priority will be.
-            _,evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True)
+            _, evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True)
             return evict_gpu_row_idxs
         else:
             raise TypeError
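For readers skimming the diff, here is a minimal, self-contained sketch of the selection logic in this hunk. It is not code from the repository: the function name `pick_evict_slots` and the toy tensors are made up for illustration. LFU evicts the slots with the smallest hit counters, while DATASET evicts the slots whose cpu_row_idx is largest, i.e. the ids that appear least often in the dataset after reordering.

```python
import torch

def pick_evict_slots(freq_cnter, cached_idx_map, evict_num, use_lfu):
    if use_lfu:
        # LFU: the slots with the smallest hit counters go first.
        _, evict_gpu_row_idxs = torch.topk(freq_cnter, evict_num, largest=False)
    else:
        # DATASET: rows were reordered by dataset frequency, so a larger
        # cpu_row_idx means a rarer id and a higher eviction priority.
        _, evict_gpu_row_idxs = torch.topk(cached_idx_map, evict_num, largest=True)
    return evict_gpu_row_idxs

freq_cnter = torch.tensor([5, 1, 9, 3])        # hypothetical per-slot hit counts
cached_idx_map = torch.tensor([7, 0, 2, 12])   # hypothetical per-slot cpu_row_idx
print(pick_evict_slots(freq_cnter, cached_idx_map, 2, use_lfu=True))   # tensor([1, 3])
print(pick_evict_slots(freq_cnter, cached_idx_map, 2, use_lfu=False))  # tensor([3, 0])
```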
@@ -163,8 +161,12 @@ class CachedParamMgr(torch.nn.Module):
         reorder the weight according to ids' frequency in dataset before training.
         Execute only once before training, also known as warmup phase.
-        :NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
-        :NOTE If you are use the LFU as the eviction strategy, you can skip this function.
+        Note:
+            If you would like to use DATASET as the eviction strategy, you must call this function.
+        Note:
+            If you use LFU as the eviction strategy, you can skip this function. If you still call it, it will initialize
+            the frequency counters in the LFU cache using the dataset statistics.

         Args:
             ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
@@ -182,24 +184,31 @@ class CachedParamMgr(torch.nn.Module):
                 with Timer() as timer:
                     # extract rows from cpu weight
                     preload_row_ids = torch.arange(preload_row_num)
-                    preload_slot_ids = preload_row_ids.cuda()
+                    preload_cuda_row_idxs = preload_row_ids.cuda()
                     if self.buffer_size > 0:
                         self.limit_buff_index_copyer.index_copy(0,
                                                                 src_index=preload_row_ids,
-                                                                tgt_index=preload_slot_ids,
+                                                                tgt_index=preload_cuda_row_idxs,
                                                                 src=self.weight.view(self.num_embeddings, -1),
                                                                 tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
                     else:
                         preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
-                        self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows)
+                        self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
+                                                                                        preload_rows)
                     # update auxiliary info
-                    slot_offsets = preload_slot_ids
-                    self.cached_idx_map[preload_slot_ids] = preload_slot_ids
-                    if self._evict_strategy == EvictionStrategy.LFU :
-                        self.freq_cnter.index_fill_(0,preload_slot_ids,0)
-                    self.inverted_cached_idx[preload_slot_ids] = slot_offsets
+                    slot_offsets = preload_cuda_row_idxs
+                    self.cached_idx_map[preload_cuda_row_idxs] = preload_cuda_row_idxs
+
+                    if self._evict_strategy == EvictionStrategy.LFU:
+                        # if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset.
+                        if ids_freq_mapping is None:
+                            self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
+                        else:
+                            self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, self.idx_map[preload_cuda_row_idxs])
+
+                    self.inverted_cached_idx[preload_cuda_row_idxs] = slot_offsets
                     self._cuda_available_row_num -= preload_row_num
                 print(f'Cache warmup finished cost {timer.elapsed} sec.')
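This hunk is the core of the commit: during warmup, preloaded slots used to start with a frequency of 0, which left them looking like cold rows to the LFU policy; now, when ids_freq_mapping is available, the counters are seeded from the dataset statistics. Below is a rough standalone sketch of that idea. The helper name `seed_lfu_counters`, the `math.ceil` sizing, and the topk-based selection of the hottest ids are illustrative assumptions, not the exact code inside `CachedParamMgr.reorder`.

```python
import math
import sys
import torch

def seed_lfu_counters(ids_freq_mapping, cuda_row_num, warmup_ratio=0.7):
    """Illustration only: preload the hottest ids and seed their LFU counters
    with the dataset frequency instead of 0."""
    freq = torch.as_tensor(ids_freq_mapping, dtype=torch.long)
    preload_row_num = min(math.ceil(cuda_row_num * warmup_ratio), freq.numel())

    # empty slots keep a sentinel value so LFU does not pick them for eviction
    # while real rows are still cached
    freq_cnter = torch.full((cuda_row_num,), sys.maxsize, dtype=torch.long)

    if preload_row_num > 0:
        top_freq, _top_ids = torch.topk(freq, preload_row_num, largest=True)
        preload_slots = torch.arange(preload_row_num)
        # seed with dataset frequency so warmed-up rows are not evicted first
        freq_cnter[preload_slots] = top_freq
    return freq_cnter

print(seed_lfu_counters([10, 3, 7, 1, 25], cuda_row_num=4, warmup_ratio=0.5))
# tensor([25, 10, 9223372036854775807, 9223372036854775807])
```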
@@ -215,7 +224,7 @@ class CachedParamMgr(torch.nn.Module):
         self.inverted_cached_idx.index_fill_(0, row_ids, -1)
         self._cuda_available_row_num += slots.numel()
-        if self._evict_strategy == EvictionStrategy.LFU :
+        if self._evict_strategy == EvictionStrategy.LFU:
             self.freq_cnter.fill_(sys.maxsize)
         assert self._cuda_available_row_num == self.cuda_row_num
         assert torch.all(self.inverted_cached_idx == -1).item()
@@ -258,7 +267,7 @@ class CachedParamMgr(torch.nn.Module):
             torch.Tensor: indices on the cuda_cached_weight.
         """
         with record_function("(zhg) get unique indices"):
-            cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts = True)
+            cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
             assert len(cpu_row_idxs) <= self.cuda_row_num, \
                 f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
@@ -283,10 +292,10 @@ class CachedParamMgr(torch.nn.Module):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)

         # update for LFU.
-        if self._evict_strategy == EvictionStrategy.LFU :
+        if self._evict_strategy == EvictionStrategy.LFU:
             unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
-            self.freq_cnter.scatter_add_(0,unique_gpu_row_idxs,repeat_times)
+            self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)

         return gpu_row_idxs

     def _reset_comm_stats(self):
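The LFU update in the hunk above is just a counting step: torch.unique reports each distinct cpu row touched by a batch together with its multiplicity, and scatter_add_ adds those multiplicities onto the counters of the matching cached slots. A tiny self-contained illustration follows; the id values and the slot mapping are made up.

```python
import torch

ids = torch.tensor([2, 2, 5, 2, 5, 7])                   # ids seen in one batch
cpu_row_idxs, repeat_times = torch.unique(ids, return_counts=True)

# hypothetical cpu_row_idx -> gpu slot mapping; -1 marks "not cached"
inverted_cached_idx = torch.full((10,), -1, dtype=torch.long)
inverted_cached_idx[torch.tensor([2, 5, 7])] = torch.tensor([0, 1, 2])

freq_cnter = torch.zeros(4, dtype=torch.long)             # one counter per gpu slot
unique_gpu_row_idxs = inverted_cached_idx[cpu_row_idxs]
freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
print(freq_cnter)   # tensor([3, 2, 1, 0]): id 2 hit 3 times, id 5 twice, id 7 once
```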
@@ -363,7 +372,7 @@ class CachedParamMgr(torch.nn.Module):
             slot_offsets = slots
             self.cached_idx_map[slots] = cpu_row_idxs
             self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
-            if self._evict_strategy == EvictionStrategy.LFU :
+            if self._evict_strategy == EvictionStrategy.LFU:
                 self.freq_cnter.index_fill_(0, slots, 0)
             self._cuda_available_row_num -= cpu_row_idxs.numel()
         self._cpu_to_cuda_elpase += timer.elapsed
@@ -407,7 +416,7 @@ class CachedParamMgr(torch.nn.Module):
             # update inverted_cached_idx, min_slot_id is evicted from cuda
             self.cached_idx_map[max_cpu_row_idx] = -1
-            if self._evict_strategy == EvictionStrategy.LFU :
+            if self._evict_strategy == EvictionStrategy.LFU:
                 self.freq_cnter[max_cpu_row_idx] = sys.maxsize
             self.inverted_cached_idx[max_gpu_row_idx] = -1
@@ -443,7 +452,7 @@ class CachedParamMgr(torch.nn.Module):
             # update the inverted_cached_idx
             self.cached_idx_map[slot_id] = row_id
-            if self._evict_strategy == EvictionStrategy.LFU :
+            if self._evict_strategy == EvictionStrategy.LFU:
                 self.freq_cnter[slot_id] = 0
             self.inverted_cached_idx[row_id] = slot_offset
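The last few hunks only fix whitespace, but together with the earlier ones they show the full lifecycle of freq_cnter: a slot is parked at sys.maxsize when its row is evicted (so empty slots are not attractive to topk(..., largest=False) while occupied slots remain) and reset to 0 when a new row is admitted. A toy trace of that lifecycle, with made-up slot contents:

```python
import sys
import torch

freq_cnter = torch.tensor([4, 9, 2, sys.maxsize])    # slot 3 is currently empty

# _evict-style step: the least frequently used occupied slot is freed
# and parked at the sentinel value
evicted_slot = int(torch.argmin(freq_cnter[:3]))      # slot 2 has the smallest count
freq_cnter[evicted_slot] = sys.maxsize

# _admit-style step: a new row reuses the freed slot and its counter restarts at 0
freq_cnter[evicted_slot] = 0
print(freq_cnter)   # tensor([4, 9, 0, 9223372036854775807])
```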