mirror of https://github.com/hpcaitech/ColossalAI
[FAW] refactor reorder() for CachedParamMgr (#1514)
parent 9feee6d06b
commit af5438caa2
@@ -172,44 +172,53 @@ class CachedParamMgr(torch.nn.Module):
             ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
             warmup_ratio (float): the amount of chunks preloaded in cuda cache
         """
-        if ids_freq_mapping is not None:
-            if not isinstance(ids_freq_mapping, torch.Tensor):
-                ids_freq_mapping = torch.tensor(ids_freq_mapping)
-            tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
-            sorted_idx = torch.argsort(tmp_idx)
-            self.idx_map.data.copy_(sorted_idx)
+        # reorder phase: reorder the cpu weight according to their freq stats in the target dataset.
+        # reorder only works for DATASET eviction strategy.
+        if ids_freq_mapping is not None and not isinstance(ids_freq_mapping, torch.Tensor):
+            ids_freq_mapping = torch.tensor(ids_freq_mapping)
+
+        if self._evict_strategy == EvictionStrategy.DATASET:
+            if ids_freq_mapping is not None:
+                tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
+                sorted_idx = torch.argsort(tmp_idx)
+                self.idx_map.data.copy_(sorted_idx)
+
+        # warmup phase: copy #preload_row_num rows from cpu to gpu.
         preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
         if preload_row_num > 0:
             with Timer() as timer:
                 # extract rows from cpu weight
-                preload_row_ids = torch.arange(preload_row_num)
-                preload_cuda_row_idxs = preload_row_ids.cuda()
+                if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None:
+                    freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
+                    preload_cuda_row_idxs = torch.arange(preload_row_num).cuda()
+                else:
+                    preload_cpu_ids = torch.arange(preload_row_num)
+                    preload_cuda_row_idxs = preload_cpu_ids.cuda()
 
                 if self.buffer_size > 0:
                     self.limit_buff_index_copyer.index_copy(0,
-                                                            src_index=preload_row_ids,
+                                                            src_index=preload_cpu_ids,
                                                             tgt_index=preload_cuda_row_idxs,
                                                             src=self.weight.view(self.num_embeddings, -1),
                                                             tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
                 else:
-                    preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
+                    preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda()
                     self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
                                                                                     preload_rows)
 
                 # update auxiliary info
-                slot_offsets = preload_cuda_row_idxs
-                self.cached_idx_map[preload_cuda_row_idxs] = preload_cuda_row_idxs
+                self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda()
+                self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs
+                self._cuda_available_row_num -= preload_row_num
 
                 if self._evict_strategy == EvictionStrategy.LFU:
                     # if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset.
                     if ids_freq_mapping is None:
                         self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
                     else:
-                        self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, self.idx_map[preload_cuda_row_idxs])
+                        self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda()
 
-                self.inverted_cached_idx[preload_cuda_row_idxs] = slot_offsets
-                self._cuda_available_row_num -= preload_row_num
                 print(f'Cache warmup finished cost {timer.elapsed} sec.')
 
     def flush(self):
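The reorder phase above relies on a double argsort: applying argsort to the result of a descending argsort turns a frequency vector into each id's rank by frequency, which is exactly the id-to-row mapping copied into self.idx_map. A minimal sketch of the trick, using made-up frequencies rather than data from this commit:

    import torch

    # Hypothetical dataset frequencies; offset = id, value = freq.
    ids_freq_mapping = torch.tensor([4, 2, 1, 3])
    # ids ordered from most to least frequent: [0, 3, 1, 2]
    tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
    # rank of each id (0 = most frequent): [0, 2, 3, 1]
    sorted_idx = torch.argsort(tmp_idx)
    # id 0 (freq 4) maps to row 0, id 3 (freq 3) maps to row 1, ...
    assert sorted_idx.tolist() == [0, 2, 3, 1]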
@@ -144,49 +144,52 @@ def test_freq_aware_embed(use_LFU: bool):
     assert torch.allclose(model_weight, ref_weight), \
         f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
 
 
-def test_lfu_strategy():
-    # minimal test to check behavior
-    Bag = FreqAwareEmbeddingBag(
-        5,
-        5,
-        cuda_row_num=3,
-        buffer_size=0,
-        pin_weight=True,
-        warmup_ratio=0.0,
-        evict_strategy=EvictionStrategy.LFU
-    )
-
-    offsets = torch.tensor([0],device="cuda:0")
+@pytest.mark.parametrize('init_freq', [True, False])
+def test_lfu_strategy(init_freq: bool):
+    # minimal test to check behavior
+    Bag = FreqAwareEmbeddingBag(5,
+                                5,
+                                cuda_row_num=3,
+                                buffer_size=0,
+                                pin_weight=True,
+                                ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
+                                warmup_ratio=1.0,
+                                evict_strategy=EvictionStrategy.LFU)
+
+    # print('cached_idx_map: ', Bag.cache_weight_mgr.cached_idx_map)
+    offsets = torch.tensor([0], device="cuda:0")
 
     # prepare frequency learning info:
-    Bag.forward(torch.tensor([2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)
 
     # check strategy
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
-    Bag.forward(torch.tensor([3],device="cuda:0"),offsets)    # miss, evict 1
-    Bag.forward(torch.tensor([2],device="cuda:0"),offsets)    # hit
-    Bag.forward(torch.tensor([4],device="cuda:0"),offsets)    # miss, evict 3
-    Bag.forward(torch.tensor([2],device="cuda:0"),offsets)    # hit
-    Bag.forward(torch.tensor([0],device="cuda:0"),offsets)    # hit
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([0, 1, 2], device="cuda:0"), offsets)
+    Bag.forward(torch.tensor([3], device="cuda:0"), offsets)    # miss, evict 1
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)    # hit
+    Bag.forward(torch.tensor([4], device="cuda:0"), offsets)    # miss, evict 3
+    Bag.forward(torch.tensor([2], device="cuda:0"), offsets)    # hit
+    Bag.forward(torch.tensor([0], device="cuda:0"), offsets)    # hit
 
     assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
         "LFU strategy behavior failed"
 
 
 def gather_tensor(tensor, rank, world_size):
     gather_list = []
     if rank == 0:
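The warmup path added in the first hunk is what this test exercises with ids_freq_mapping=[4, 2, 1, 3, 1] and warmup_ratio=1.0: when the eviction strategy is LFU and dataset frequencies are available, torch.topk selects the preload_row_num most frequent cpu rows, and their frequencies seed freq_cnter so the preloaded rows start with realistic LFU counts instead of zero. A standalone sketch of that selection with the same numbers the test uses:

    import torch

    ids_freq_mapping = torch.tensor([4, 2, 1, 3, 1])
    preload_row_num = 3    # matches cuda_row_num=3 in the test

    # the 3 most frequent cpu rows and their dataset frequencies
    freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
    assert preload_cpu_ids.tolist() == [0, 3, 1]    # ids 0, 3, 1 get preloaded
    assert freq_value.tolist() == [4, 3, 2]         # and seed the LFU counter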
@@ -279,4 +282,4 @@ def test_parallel_freq_aware_embed(world_size):
 
 if __name__ == '__main__':
     # test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)
-    test_lfu_strategy()
+    test_lfu_strategy(False)
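Note that @pytest.mark.parametrize only generates the two init_freq cases under the pytest runner; when the file is run as a script, the decorator is inert and the test must be called with an explicit argument, which is why the __main__ block now passes False. A sketch of covering both cases in script mode:

    if __name__ == '__main__':
        # parametrize does nothing outside pytest, so pass init_freq explicitly
        for init_freq in (True, False):
            test_lfu_strategy(init_freq)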