[FAW] LFU cache for the FAW

pull/1494/head
CsRic 2022-08-25 13:08:46 +08:00 committed by GitHub
parent 9145aef2b4
commit b8d0e39eaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 9 deletions

View File

@ -59,7 +59,7 @@ class CachedParamMgr(torch.nn.Module):
"""_update_freq_cnter
Update the frequency valude w.r.t. the cpu_row_ids in self.freq_cnter.
Args:
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
"""
@ -80,7 +80,7 @@ class CachedParamMgr(torch.nn.Module):
if self._evict_strategy == EvictionStrategy.LFU:
# find the minimal evict_num freq entries in cached_idx_map
evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
return self.cached_idx_map[evict_gpu_row_idxs]
return evict_gpu_row_idxs
elif self._evict_strategy == EvictionStrategy.DATASET:
# cached_idx_map itself implies the priority of eviction.
# The value of self.cached_idx_map represents cpu_row_idx.
@ -298,15 +298,27 @@ class CachedParamMgr(torch.nn.Module):
if evict_num > 0:
with Timer() as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
if self._evict_strategy == EvictionStrategy.DATASET:
# mask method.
# set cached_idx_map[invalid_idxs] to -2.
# so those idxs will be sorted to end, therefore not being chosen as victim
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
elif self._evict_strategy == EvictionStrategy.LFU:
# another mask method.
# set freq_cnter[invalid_idxs] to max
# so those idxs will be sorted to end, therefore not being chosen as victim
backup_cnter = self.freq_cnter[invalid_idxs].clone()
self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1) # or can we use a confident max value?
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.freq_cnter.index_copy_(0,invalid_idxs,backup_cnter)
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
if self.buffer_size > 0:

View File

@ -144,6 +144,44 @@ def test_freq_aware_embed(use_LFU: bool):
assert torch.allclose(model_weight, ref_weight), \
f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
def test_lfu_strategy():
# minimal test to check behavior
Bag = FreqAwareEmbeddingBag(
5,
5,
cuda_row_num=3,
buffer_size=0,
pin_weight=True,
warmup_ratio=0.0,
evict_strategy=EvictionStrategy.LFU
)
offsets = torch.tensor([0],device="cuda:0")
# prepare frequency learning info:
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
# check strategy
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([3],device="cuda:0"),offsets) # miss, evict 1
Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
Bag.forward(torch.tensor([4],device="cuda:0"),offsets) # miss, evict 1
Bag.forward(torch.tensor([2],device="cuda:0"),offsets) # hit
Bag.forward(torch.tensor([0],device="cuda:0"),offsets) # hit
assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
"LFU strategy behavior failed"
def gather_tensor(tensor, rank, world_size):
gather_list = []
@ -237,3 +275,4 @@ def test_parallel_freq_aware_embed(world_size):
if __name__ == '__main__':
test_freq_aware_embed(True)
# test_parallel_freq_aware_embed(2)
# test_lfu_strategy()