mirror of https://github.com/hpcaitech/ColossalAI
[FAW] LFU cache for the FAW
parent
9145aef2b4
commit
b8d0e39eaf
|
@ -59,7 +59,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
"""_update_freq_cnter
|
||||
|
||||
Update the frequency value w.r.t. the cpu_row_idxs in self.freq_cnter.
|
||||
|
||||
|
||||
Args:
|
||||
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
|
||||
"""
|
||||
|
@ -80,7 +80,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
# find the minimal evict_num freq entries in cached_idx_map
|
||||
evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
|
||||
return self.cached_idx_map[evict_gpu_row_idxs]
|
||||
return evict_gpu_row_idxs
|
||||
elif self._evict_strategy == EvictionStrategy.DATASET:
|
||||
# cached_idx_map itself implies the priority of eviction.
|
||||
# The value of self.cached_idx_map represents cpu_row_idx.
|
||||
|
@ -298,15 +298,27 @@ class CachedParamMgr(torch.nn.Module):
|
|||
if evict_num > 0:
|
||||
with Timer() as timer:
|
||||
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
||||
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
||||
|
||||
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
||||
|
||||
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
||||
|
||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||
|
||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.DATASET:
|
||||
# mask method.
|
||||
# set cached_idx_map[invalid_idxs] to -2.
|
||||
# so those idxs will be sorted to end, therefore not being chosen as victim
|
||||
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
||||
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||
|
||||
elif self._evict_strategy == EvictionStrategy.LFU:
|
||||
# another mask method.
|
||||
# set freq_cnter[invalid_idxs] to max
|
||||
# so those idxs will be sorted to end, therefore not being chosen as victim
|
||||
backup_cnter = self.freq_cnter[invalid_idxs].clone()
|
||||
self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1) # or can we use a confident max value?
|
||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||
self.freq_cnter.index_copy_(0,invalid_idxs,backup_cnter)
|
||||
|
||||
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
||||
|
||||
if self.buffer_size > 0:
|
||||
|
|
|
@ -144,6 +144,44 @@ def test_freq_aware_embed(use_LFU: bool):
|
|||
assert torch.allclose(model_weight, ref_weight), \
|
||||
f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
|
||||
|
||||
def test_lfu_strategy():
    """Minimal behavioural check of the LFU eviction strategy.

    A 5-row embedding bag is built with room for 3 cached rows
    (``cuda_row_num=3``), no buffer and no warm-up. During the preparation
    phase row 0 is looked up 12 times, row 2 8 times and row 1 4 times.
    The check phase then issues six lookups and compares the last six
    entries of ``num_hits_history`` against the expected hit counts.
    """
    bag = FreqAwareEmbeddingBag(
        5,
        5,
        cuda_row_num=3,
        buffer_size=0,
        pin_weight=True,
        warmup_ratio=0.0,
        evict_strategy=EvictionStrategy.LFU,
    )
    offsets = torch.tensor([0], device="cuda:0")

    def lookup(row_ids):
        # One forward pass over the given row ids, all in a single bag.
        bag.forward(torch.tensor(row_ids, device="cuda:0"), offsets)

    # prepare frequency learning info: each access pattern is replayed 4 times
    for row_ids in ([0, 1, 2], [0, 2], [0]):
        for _ in range(4):
            lookup(row_ids)

    # check strategy: (row ids, what the original test expects to happen)
    check_sequence = (
        [0, 1, 2],
        [3],    # miss, evict 1
        [2],    # hit
        [4],    # miss, evict 1
        [2],    # hit
        [0],    # hit
    )
    for row_ids in check_sequence:
        lookup(row_ids)

    recent_hits = torch.Tensor(bag.cache_weight_mgr.num_hits_history[-6:])
    expected_hits = torch.Tensor([3, 0, 1, 0, 1, 1])
    assert torch.allclose(recent_hits, expected_hits), \
        "LFU strategy behavior failed"
|
||||
|
||||
def gather_tensor(tensor, rank, world_size):
|
||||
gather_list = []
|
||||
|
@ -237,3 +275,4 @@ def test_parallel_freq_aware_embed(world_size):
|
|||
if __name__ == '__main__':
|
||||
test_freq_aware_embed(True)
|
||||
# test_parallel_freq_aware_embed(2)
|
||||
# test_lfu_strategy()
|
Loading…
Reference in New Issue