[FAW] shrink freq_cnter size (#1509)

pull/1513/head
CsRic 2 years ago committed by GitHub
parent f8945eef17
commit 1b8fee8e9c

@@ -14,6 +14,7 @@ class EvictionStrategy(Enum):
DATASET = 2
class CachedParamMgr(torch.nn.Module):
"""
Manage Embedding Weights on CPU and CUDA memory using a software cache.
@@ -46,7 +47,6 @@ class CachedParamMgr(torch.nn.Module):
self.cuda_row_num = cuda_row_num
self._cuda_available_row_num = self.cuda_row_num
self.pin_weight = pin_weight
self.elem_size_in_byte = weight.element_size()
# weight configure
@@ -61,31 +61,13 @@ class CachedParamMgr(torch.nn.Module):
self._evict_strategy = evict_strategy
if self._evict_strategy == EvictionStrategy.LFU:
# cpu_row_idx -> frequency, freq of the cpu rows.
# evict the minimal freq value row in cuda cache.
'''
The last element of `freq_cnter` is set to the maximum value of int.
Rows whose value in `self.cached_idx_map` is -1 store nothing (are unused) in `self.cuda_weight`.
In this way, the unused rows are sorted to the end and are not chosen as eviction victims.
'''
# cache_row_idx -> frequency, freq of the cache rows.
# classic lfu cache. evict the minimal freq value row in cuda cache.
self.register_buffer("freq_cnter",
torch.empty(self.num_embeddings + 1,
torch.empty(self.cuda_row_num,
device=torch.cuda.current_device(),
dtype=torch.long).fill_(0),
dtype=torch.long).fill_(sys.maxsize),
persistent=False)
self.freq_cnter[-1] = sys.maxsize
def _update_freq_cnter(self, cpu_row_idxs_original: torch.Tensor) -> None:
"""_update_freq_cnter
Update the frequency values w.r.t. the cpu row indices in self.freq_cnter.
Args:
cpu_row_idxs_original (torch.Tensor): indices of cpu weight rows, possibly repeated.
"""
if self._evict_strategy == EvictionStrategy.LFU:
add_num = torch.bincount(cpu_row_idxs_original)
self.freq_cnter[:add_num.shape[0]] += add_num
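Why the shrink matters: the old `freq_cnter` kept one int64 counter per embedding row (plus a sentinel element), while the new one keeps one per CUDA cache row. With purely illustrative sizes (these numbers are not from the PR), the saving looks roughly like this:

# illustrative sizes only, not taken from the PR
num_embeddings = 50_000_000   # rows in the full embedding table (hypothetical)
cuda_row_num = 1_000_000      # rows resident in the CUDA cache (hypothetical)
bytes_per_counter = 8         # int64

print((num_embeddings + 1) * bytes_per_counter / 2**20)  # old freq_cnter: ~381 MiB
print(cuda_row_num * bytes_per_counter / 2**20)           # new freq_cnter: ~7.6 MiB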
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
"""_find_evict_gpu_idxs
@@ -100,14 +82,15 @@ class CachedParamMgr(torch.nn.Module):
"""
if self._evict_strategy == EvictionStrategy.LFU:
# find the evict_num entries with the minimal freq
evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
_,evict_gpu_row_idxs = torch.topk(self.freq_cnter,evict_num,largest=False)
return evict_gpu_row_idxs
elif self._evict_strategy == EvictionStrategy.DATASET:
# cached_idx_map itself implies the priority of eviction.
# The value of self.cached_idx_map represents cpu_row_idx.
# The larger it is, the less frequently it will appear in the dataset,
# and the higher its eviction priority will be.
return torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
_,evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True)
return evict_gpu_row_idxs
else:
raise TypeError
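The victim search itself also gets cheaper: instead of fully sorting and slicing, `torch.topk(..., largest=False)` selects just the `evict_num` smallest entries. A minimal sketch with a made-up counter tensor (the real old code indexed the counters through `cached_idx_map` first; unused slots sit at `sys.maxsize`, matching the new initialization):

import sys
import torch

freq_cnter = torch.tensor([5, 1, sys.maxsize, 0, 2])  # toy per-slot counters
evict_num = 2

old_victims = torch.argsort(freq_cnter)[:evict_num]               # old: full sort, then slice
_, new_victims = torch.topk(freq_cnter, evict_num, largest=False)  # new: partial selection

assert set(old_victims.tolist()) == set(new_victims.tolist())      # both pick slots 3 and 1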
@@ -181,8 +164,7 @@ class CachedParamMgr(torch.nn.Module):
Execute only once before training, also known as warmup phase.
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
:NOTE If you use LFU as the eviction strategy, you can skip this function. The `freq_cnter` will be initialized as all zeros.
You can also call this function to initialize the `freq_cnter` with dataset frequency statistics.
:NOTE If you use LFU as the eviction strategy, you can skip this function.
Args:
ids_freq_mapping (List[int]): a list whose offset is the id number and whose value is the frequency. If None, the cpu weight is not reordered.
@@ -194,9 +176,6 @@ class CachedParamMgr(torch.nn.Module):
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
sorted_idx = torch.argsort(tmp_idx)
self.idx_map.data.copy_(sorted_idx)
#initialize freq_cnter if use LFU
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter[:-1], _ = torch.sort(ids_freq_mapping)
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
if preload_row_num > 0:
@@ -218,6 +197,8 @@ class CachedParamMgr(torch.nn.Module):
# update auxiliary info
slot_offsets = preload_slot_ids
self.cached_idx_map[preload_slot_ids] = preload_slot_ids
if self._evict_strategy == EvictionStrategy.LFU :
self.freq_cnter.index_fill_(0,preload_slot_ids,0)
self.inverted_cached_idx[preload_slot_ids] = slot_offsets
self._cuda_available_row_num -= preload_row_num
print(f'Cache warmup finished cost {timer.elapsed} sec.')
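After warmup under LFU, only the preloaded slots should start counting from zero; the rest keep the `sys.maxsize` they were initialized with, so a `topk(..., largest=False)` victim search never lands on an empty slot. A small sketch with hypothetical sizes:

import sys
import torch

cuda_row_num = 6
freq_cnter = torch.full((cuda_row_num,), sys.maxsize, dtype=torch.long)

preload_slot_ids = torch.arange(3)              # pretend warmup filled slots 0..2
freq_cnter.index_fill_(0, preload_slot_ids, 0)  # preloaded slots start counting from 0

print(freq_cnter)  # slots 0..2 are 0, the remaining slots stay at sys.maxsize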
@@ -234,6 +215,8 @@ class CachedParamMgr(torch.nn.Module):
self.inverted_cached_idx.index_fill_(0, row_ids, -1)
self._cuda_available_row_num += slots.numel()
if self._evict_strategy == EvictionStrategy.LFU :
self.freq_cnter.fill_(sys.maxsize)
assert self._cuda_available_row_num == self.cuda_row_num
assert torch.all(self.inverted_cached_idx == -1).item()
assert torch.all(self.cached_idx_map == -1).item()
@@ -275,8 +258,7 @@ class CachedParamMgr(torch.nn.Module):
torch.Tensor: indices on the cuda_cached_weight.
"""
with record_function("(zhg) get unique indices"):
cpu_row_idxs_original = self.idx_map.index_select(0, ids)
cpu_row_idxs = torch.unique(cpu_row_idxs_original)
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts = True)
assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
@@ -301,7 +283,10 @@ class CachedParamMgr(torch.nn.Module):
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
# update for LFU.
self._update_freq_cnter(cpu_row_idxs_original)
if self._evict_strategy == EvictionStrategy.LFU :
unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
self.freq_cnter.scatter_add_(0,unique_gpu_row_idxs,repeat_times)
return gpu_row_idxs
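With the counter now indexed by cache slot, the per-batch bookkeeping rides on the same `torch.unique(..., return_counts=True)` call: each unique row's slot (looked up through `inverted_cached_idx`) is bumped by how often that row appeared in the batch. A minimal sketch with made-up state, not the real buffers:

import torch

# hypothetical state: rows 3, 4, 7, 11 are cached in slots 0, 1, 3, 2; everything else is -1
inverted_cached_idx = torch.full((14,), -1, dtype=torch.long)
inverted_cached_idx[torch.tensor([3, 4, 7, 11])] = torch.tensor([0, 1, 3, 2])
freq_cnter = torch.zeros(4, dtype=torch.long)

mapped_ids = torch.tensor([3, 3, 4, 11, 3])                    # a batch after the idx_map lookup
cpu_row_idxs, repeat_times = torch.unique(mapped_ids, return_counts=True)

unique_gpu_row_idxs = inverted_cached_idx[cpu_row_idxs]        # slot index of each unique row
freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)  # per-slot hits for this batch
print(freq_cnter)                                              # tensor([3, 1, 1, 0])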
def _reset_comm_stats(self):
@@ -324,23 +309,21 @@ class CachedParamMgr(torch.nn.Module):
if evict_num > 0:
with Timer() as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
if self._evict_strategy == EvictionStrategy.DATASET:
# mask method.
# set cached_idx_map[invalid_idxs] to -2.
# so those idxs will be sorted to the end and therefore not chosen as victims
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
elif self._evict_strategy == EvictionStrategy.LFU:
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
backup_freqs = self.freq_cnter[invalid_idxs].clone()
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
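The backlist handling works the same way in both branches: slots that hold rows needed by the current batch are temporarily marked so the victim search skips them, then restored. Under LFU that marking is done on `freq_cnter` with `sys.maxsize`. A sketch with toy tensors:

import sys
import torch

freq_cnter = torch.tensor([4, 1, 7, 2], dtype=torch.long)          # toy per-slot counters
cached_idx_map = torch.tensor([10, 11, 12, 13], dtype=torch.long)  # cpu row held by each slot
evict_backlist = torch.tensor([11])                                # rows the current batch needs

mask = torch.isin(cached_idx_map, evict_backlist)
invalid_idxs = torch.nonzero(mask).squeeze(1)
backup_freqs = freq_cnter[invalid_idxs].clone()
freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)   # make backlisted slots look infinitely hot

_, victims = torch.topk(freq_cnter, 1, largest=False)  # picks slot 3 (freq 2), not slot 1
freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)  # restore the real counts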
@@ -357,6 +340,7 @@ class CachedParamMgr(torch.nn.Module):
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
self.inverted_cached_idx.index_fill_(0, evict_info, -1)
# self.freq_cnter.index_fill(0, evict_gpu_row_idxs, sys.maxsize) # unnecessary
self._cuda_available_row_num += evict_num
weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim
@@ -379,6 +363,8 @@ class CachedParamMgr(torch.nn.Module):
slot_offsets = slots
self.cached_idx_map[slots] = cpu_row_idxs
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
if self._evict_strategy == EvictionStrategy.LFU :
self.freq_cnter.index_fill_(0, slots, 0)
self._cuda_available_row_num -= cpu_row_idxs.numel()
self._cpu_to_cuda_elpase += timer.elapsed
weight_size = cpu_row_idxs.numel() * self.embedding_dim
@@ -421,7 +407,8 @@ class CachedParamMgr(torch.nn.Module):
# update inverted_cached_idx, min_slot_id is evicted from cuda
self.cached_idx_map[max_cpu_row_idx] = -1
if self._evict_strategy == EvictionStrategy.LFU :
self.freq_cnter[max_cpu_row_idx] = sys.maxsize
self.inverted_cached_idx[max_gpu_row_idx] = -1
self._cuda_available_row_num += 1
@@ -456,6 +443,8 @@ class CachedParamMgr(torch.nn.Module):
# update the inverted_cached_idx
self.cached_idx_map[slot_id] = row_id
if self._evict_strategy == EvictionStrategy.LFU :
self.freq_cnter[slot_id] = 0
self.inverted_cached_idx[row_id] = slot_offset
self._cuda_available_row_num -= 1

@@ -177,9 +177,10 @@ def test_lfu_strategy():
# check strategy
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([3],device="cuda:0"),offsets) # miss, evict 1
Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
Bag.forward(torch.tensor([4],device="cuda:0"),offsets) # miss, evict 1
Bag.forward(torch.tensor([4],device="cuda:0"),offsets) # miss, evict 3
Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
Bag.forward(torch.tensor([0],device="cuda:0"),offsets) # hit
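For intuition, here is a self-contained toy replay of the LFU slot bookkeeping in the same spirit as the test above. It uses plain tensors rather than the real CachedParamMgr or the EmbeddingBag wrapper, and the access pattern is chosen so that no two slots tie on frequency at eviction time (torch.topk breaks ties arbitrarily), so it is an illustration only:

import sys
import torch

cuda_row_num = 3
freq_cnter = torch.full((cuda_row_num,), sys.maxsize, dtype=torch.long)
cached_idx_map = torch.full((cuda_row_num,), -1, dtype=torch.long)

def access(row):
    hit = (cached_idx_map == row).nonzero()
    if hit.numel():                                        # hit: reuse the slot
        slot = hit.item()
    else:                                                  # miss
        free = (cached_idx_map == -1).nonzero()
        if free.numel():                                   # fill a free slot first
            slot = free[0].item()
        else:                                              # evict the least-frequent slot
            _, victim = torch.topk(freq_cnter, 1, largest=False)
            slot = victim.item()
        cached_idx_map[slot] = row
        freq_cnter[slot] = 0
    freq_cnter[slot] += 1

for row in [0, 1, 1, 2, 2, 2, 3, 2, 4, 2]:
    access(row)
print(cached_idx_map, freq_cnter)  # tensor([4, 1, 2]) tensor([1, 2, 5]): hot row 2 survives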
