mirror of https://github.com/hpcaitech/ColossalAI
[FAW] LFU initialize with dataset freq (#1513)
parent
1b8fee8e9c
commit
9feee6d06b
|
@ -14,7 +14,6 @@ class EvictionStrategy(Enum):
|
|||
DATASET = 2
|
||||
|
||||
|
||||
|
||||
class CachedParamMgr(torch.nn.Module):
|
||||
"""
|
||||
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
|
||||
|
@ -64,8 +63,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
# cache_row_idx -> frequency, freq of the cache rows.
|
||||
# classic lfu cache. evict the minimal freq value row in cuda cache.
|
||||
self.register_buffer("freq_cnter",
|
||||
torch.empty(self.cuda_row_num,
|
||||
device=torch.cuda.current_device(),
|
||||
torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
|
||||
dtype=torch.long).fill_(sys.maxsize),
|
||||
persistent=False)
|
||||
|
||||
|
@ -82,14 +80,14 @@ class CachedParamMgr(torch.nn.Module):
|
|||
"""
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
# find the minimal evict_num freq entries in cached_idx_map
|
||||
_,evict_gpu_row_idxs = torch.topk(self.freq_cnter,evict_num,largest=False)
|
||||
_, evict_gpu_row_idxs = torch.topk(self.freq_cnter, evict_num, largest=False)
|
||||
return evict_gpu_row_idxs
|
||||
elif self._evict_strategy == EvictionStrategy.DATASET:
|
||||
# cached_idx_map itself implies the priority of eviction.
|
||||
# The value of self.cached_idx_map represents cpu_row_idx.
|
||||
# The larger it is, the less frequently it will appear in the dataset,
|
||||
# and the higher its eviction priority will be.
|
||||
_,evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True)
|
||||
_, evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True)
|
||||
return evict_gpu_row_idxs
|
||||
else:
|
||||
raise TypeError
|
||||
|
@ -163,8 +161,12 @@ class CachedParamMgr(torch.nn.Module):
|
|||
reorder the weight according to ids' frequency in dataset before training.
|
||||
Execute only once before training, also known as warmup phase.
|
||||
|
||||
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
|
||||
:NOTE If you are use the LFU as the eviction strategy, you can skip this function.
|
||||
Note:
|
||||
If you would like to use the DATASET as the eviction strategy, you must call this function.
|
||||
|
||||
Note:
|
||||
If you are use the LFU as the eviction strategy, you can skip this function. If you still use this function. It will initialize
|
||||
The frequency in LFU cache using the dataset statistics.
|
||||
|
||||
Args:
|
||||
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
|
||||
|
@ -182,24 +184,31 @@ class CachedParamMgr(torch.nn.Module):
|
|||
with Timer() as timer:
|
||||
# extract rows from cpu weight
|
||||
preload_row_ids = torch.arange(preload_row_num)
|
||||
preload_slot_ids = preload_row_ids.cuda()
|
||||
preload_cuda_row_idxs = preload_row_ids.cuda()
|
||||
|
||||
if self.buffer_size > 0:
|
||||
self.limit_buff_index_copyer.index_copy(0,
|
||||
src_index=preload_row_ids,
|
||||
tgt_index=preload_slot_ids,
|
||||
tgt_index=preload_cuda_row_idxs,
|
||||
src=self.weight.view(self.num_embeddings, -1),
|
||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||
else:
|
||||
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
|
||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows)
|
||||
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
|
||||
preload_rows)
|
||||
|
||||
# update auxiliary info
|
||||
slot_offsets = preload_slot_ids
|
||||
self.cached_idx_map[preload_slot_ids] = preload_slot_ids
|
||||
if self._evict_strategy == EvictionStrategy.LFU :
|
||||
self.freq_cnter.index_fill_(0,preload_slot_ids,0)
|
||||
self.inverted_cached_idx[preload_slot_ids] = slot_offsets
|
||||
slot_offsets = preload_cuda_row_idxs
|
||||
self.cached_idx_map[preload_cuda_row_idxs] = preload_cuda_row_idxs
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
# if the ids_freq_mapping is not None, we initialize the embedding row's freq value in LFU as its freq in dataset.
|
||||
if ids_freq_mapping is None:
|
||||
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
|
||||
else:
|
||||
self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, self.idx_map[preload_cuda_row_idxs])
|
||||
|
||||
self.inverted_cached_idx[preload_cuda_row_idxs] = slot_offsets
|
||||
self._cuda_available_row_num -= preload_row_num
|
||||
print(f'Cache warmup finished cost {timer.elapsed} sec.')
|
||||
|
||||
|
@ -215,7 +224,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
self.inverted_cached_idx.index_fill_(0, row_ids, -1)
|
||||
self._cuda_available_row_num += slots.numel()
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.LFU :
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
self.freq_cnter.fill_(sys.maxsize)
|
||||
assert self._cuda_available_row_num == self.cuda_row_num
|
||||
assert torch.all(self.inverted_cached_idx == -1).item()
|
||||
|
@ -258,7 +267,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
torch.Tensor: indices on the cuda_cached_weight.
|
||||
"""
|
||||
with record_function("(zhg) get unique indices"):
|
||||
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts = True)
|
||||
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
|
||||
|
||||
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
||||
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
|
||||
|
@ -283,10 +292,10 @@ class CachedParamMgr(torch.nn.Module):
|
|||
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
|
||||
|
||||
# update for LFU.
|
||||
if self._evict_strategy == EvictionStrategy.LFU :
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
|
||||
self.freq_cnter.scatter_add_(0,unique_gpu_row_idxs,repeat_times)
|
||||
|
||||
self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
|
||||
|
||||
return gpu_row_idxs
|
||||
|
||||
def _reset_comm_stats(self):
|
||||
|
@ -363,7 +372,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
slot_offsets = slots
|
||||
self.cached_idx_map[slots] = cpu_row_idxs
|
||||
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
|
||||
if self._evict_strategy == EvictionStrategy.LFU :
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
self.freq_cnter.index_fill_(0, slots, 0)
|
||||
self._cuda_available_row_num -= cpu_row_idxs.numel()
|
||||
self._cpu_to_cuda_elpase += timer.elapsed
|
||||
|
@ -407,7 +416,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
|
||||
# update inverted_cached_idx, min_slot_id is evicted from cuda
|
||||
self.cached_idx_map[max_cpu_row_idx] = -1
|
||||
if self._evict_strategy == EvictionStrategy.LFU :
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
self.freq_cnter[max_cpu_row_idx] = sys.maxsize
|
||||
self.inverted_cached_idx[max_gpu_row_idx] = -1
|
||||
|
||||
|
@ -443,7 +452,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
|
||||
# update the inverted_cached_idx
|
||||
self.cached_idx_map[slot_id] = row_id
|
||||
if self._evict_strategy == EvictionStrategy.LFU :
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
self.freq_cnter[slot_id] = 0
|
||||
self.inverted_cached_idx[row_id] = slot_offset
|
||||
|
||||
|
|
Loading…
Reference in New Issue