|
|
|
@ -14,6 +14,7 @@ class EvictionStrategy(Enum):
|
|
|
|
|
DATASET = 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
|
|
|
|
@ -46,7 +47,6 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
self.cuda_row_num = cuda_row_num
|
|
|
|
|
self._cuda_available_row_num = self.cuda_row_num
|
|
|
|
|
self.pin_weight = pin_weight
|
|
|
|
|
|
|
|
|
|
self.elem_size_in_byte = weight.element_size()
|
|
|
|
|
|
|
|
|
|
# weight configure
|
|
|
|
@ -61,31 +61,13 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
self._evict_strategy = evict_strategy
|
|
|
|
|
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU:
|
|
|
|
|
# cpu_row_idx -> frequency, freq of the cpu rows.
|
|
|
|
|
# evict the minimal freq value row in cuda cache.
|
|
|
|
|
'''
|
|
|
|
|
The last element of `freq_cnter` is set to the maximum value of int.
|
|
|
|
|
The rows store nothing (not used) in the `self.cuda_weight` whose value is -1 in `self.cached_idx_map`.
|
|
|
|
|
In this way, the not used rows are placed at the end of the sorted.
|
|
|
|
|
'''
|
|
|
|
|
# cache_row_idx -> frequency, freq of the cache rows.
|
|
|
|
|
# classic lfu cache. evict the minimal freq value row in cuda cache.
|
|
|
|
|
self.register_buffer("freq_cnter",
|
|
|
|
|
torch.empty(self.num_embeddings + 1,
|
|
|
|
|
torch.empty(self.cuda_row_num,
|
|
|
|
|
device=torch.cuda.current_device(),
|
|
|
|
|
dtype=torch.long).fill_(0),
|
|
|
|
|
dtype=torch.long).fill_(sys.maxsize),
|
|
|
|
|
persistent=False)
|
|
|
|
|
self.freq_cnter[-1] = sys.maxsize
|
|
|
|
|
|
|
|
|
|
def _update_freq_cnter(self, cpu_row_idxs_original: torch.Tensor) -> None:
|
|
|
|
|
"""_update_freq_cnter
|
|
|
|
|
|
|
|
|
|
Update the frequency valude w.r.t. the cpu_row_ids in self.freq_cnter.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
|
|
|
|
|
"""
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU:
|
|
|
|
|
add_num = torch.bincount(cpu_row_idxs_original)
|
|
|
|
|
self.freq_cnter[:add_num.shape[0]] += add_num
|
|
|
|
|
|
|
|
|
|
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
|
|
|
|
|
"""_find_evict_gpu_idxs
|
|
|
|
@ -100,14 +82,15 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU:
|
|
|
|
|
# find the minimal evict_num freq entries in cached_idx_map
|
|
|
|
|
evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
|
|
|
|
|
_,evict_gpu_row_idxs = torch.topk(self.freq_cnter,evict_num,largest=False)
|
|
|
|
|
return evict_gpu_row_idxs
|
|
|
|
|
elif self._evict_strategy == EvictionStrategy.DATASET:
|
|
|
|
|
# cached_idx_map itself implies the priority of eviction.
|
|
|
|
|
# The value of self.cached_idx_map represents cpu_row_idx.
|
|
|
|
|
# The larger it is, the less frequently it will appear in the dataset,
|
|
|
|
|
# and the higher its eviction priority will be.
|
|
|
|
|
return torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
|
|
|
|
|
_,evict_gpu_row_idxs = torch.topk(self.cached_idx_map, evict_num, largest=True)
|
|
|
|
|
return evict_gpu_row_idxs
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError
|
|
|
|
|
|
|
|
|
@ -181,8 +164,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
Execute only once before training, also known as warmup phase.
|
|
|
|
|
|
|
|
|
|
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
|
|
|
|
|
:NOTE If you are use the LFU as the eviction strategy, you can skip this function. The `freq_cnter` will be initialized as all zeros.
|
|
|
|
|
You can also call this function to inialized the `freq_cnter` with dataset frequency statistics.
|
|
|
|
|
:NOTE If you are use the LFU as the eviction strategy, you can skip this function.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
|
|
|
|
@ -194,9 +176,6 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
|
|
|
|
|
sorted_idx = torch.argsort(tmp_idx)
|
|
|
|
|
self.idx_map.data.copy_(sorted_idx)
|
|
|
|
|
#initialize freq_cnter if use LFU
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU:
|
|
|
|
|
self.freq_cnter[:-1], _ = torch.sort(ids_freq_mapping)
|
|
|
|
|
|
|
|
|
|
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
|
|
|
|
|
if preload_row_num > 0:
|
|
|
|
@ -218,6 +197,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
# update auxiliary info
|
|
|
|
|
slot_offsets = preload_slot_ids
|
|
|
|
|
self.cached_idx_map[preload_slot_ids] = preload_slot_ids
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU :
|
|
|
|
|
self.freq_cnter.index_fill_(0,preload_slot_ids,0)
|
|
|
|
|
self.inverted_cached_idx[preload_slot_ids] = slot_offsets
|
|
|
|
|
self._cuda_available_row_num -= preload_row_num
|
|
|
|
|
print(f'Cache warmup finished cost {timer.elapsed} sec.')
|
|
|
|
@ -234,6 +215,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
self.inverted_cached_idx.index_fill_(0, row_ids, -1)
|
|
|
|
|
self._cuda_available_row_num += slots.numel()
|
|
|
|
|
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU :
|
|
|
|
|
self.freq_cnter.fill_(sys.maxsize)
|
|
|
|
|
assert self._cuda_available_row_num == self.cuda_row_num
|
|
|
|
|
assert torch.all(self.inverted_cached_idx == -1).item()
|
|
|
|
|
assert torch.all(self.cached_idx_map == -1).item()
|
|
|
|
@ -275,8 +258,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
torch.Tensor: indices on the cuda_cached_weight.
|
|
|
|
|
"""
|
|
|
|
|
with record_function("(zhg) get unique indices"):
|
|
|
|
|
cpu_row_idxs_original = self.idx_map.index_select(0, ids)
|
|
|
|
|
cpu_row_idxs = torch.unique(cpu_row_idxs_original)
|
|
|
|
|
cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts = True)
|
|
|
|
|
|
|
|
|
|
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
|
|
|
|
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
|
|
|
|
@ -301,7 +283,10 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
|
|
|
|
|
|
|
|
|
|
# update for LFU.
|
|
|
|
|
self._update_freq_cnter(cpu_row_idxs_original)
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU :
|
|
|
|
|
unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
|
|
|
|
|
self.freq_cnter.scatter_add_(0,unique_gpu_row_idxs,repeat_times)
|
|
|
|
|
|
|
|
|
|
return gpu_row_idxs
|
|
|
|
|
|
|
|
|
|
def _reset_comm_stats(self):
|
|
|
|
@ -324,23 +309,21 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
if evict_num > 0:
|
|
|
|
|
with Timer() as timer:
|
|
|
|
|
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
|
|
|
|
|
|
|
|
|
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.DATASET:
|
|
|
|
|
# mask method.
|
|
|
|
|
# set cached_idx_map[invalid_idxs] to -2.
|
|
|
|
|
# so those idxs will be sorted to end, therefore not being chosen as victim
|
|
|
|
|
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
|
|
|
|
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
|
|
|
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
|
|
|
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
|
|
|
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
|
|
|
|
|
|
|
|
|
elif self._evict_strategy == EvictionStrategy.LFU:
|
|
|
|
|
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
|
|
|
|
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
|
|
|
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
|
|
|
|
|
backup_freqs = self.freq_cnter[invalid_idxs].clone()
|
|
|
|
|
self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
|
|
|
|
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
|
|
|
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
|
|
|
|
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
|
|
|
|
|
|
|
|
|
|
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
|
|
|
|
|
|
|
|
@ -357,6 +340,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
|
|
|
|
|
self.inverted_cached_idx.index_fill_(0, evict_info, -1)
|
|
|
|
|
# self.freq_cnter.index_fill(0, evict_gpu_row_idxs, sys.maxsize) # unnecessary
|
|
|
|
|
self._cuda_available_row_num += evict_num
|
|
|
|
|
|
|
|
|
|
weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim
|
|
|
|
@ -379,6 +363,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
slot_offsets = slots
|
|
|
|
|
self.cached_idx_map[slots] = cpu_row_idxs
|
|
|
|
|
self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU :
|
|
|
|
|
self.freq_cnter.index_fill_(0, slots, 0)
|
|
|
|
|
self._cuda_available_row_num -= cpu_row_idxs.numel()
|
|
|
|
|
self._cpu_to_cuda_elpase += timer.elapsed
|
|
|
|
|
weight_size = cpu_row_idxs.numel() * self.embedding_dim
|
|
|
|
@ -421,7 +407,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
# update inverted_cached_idx, min_slot_id is evicted from cuda
|
|
|
|
|
self.cached_idx_map[max_cpu_row_idx] = -1
|
|
|
|
|
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU :
|
|
|
|
|
self.freq_cnter[max_cpu_row_idx] = sys.maxsize
|
|
|
|
|
self.inverted_cached_idx[max_gpu_row_idx] = -1
|
|
|
|
|
|
|
|
|
|
self._cuda_available_row_num += 1
|
|
|
|
@ -456,6 +443,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
# update the inverted_cached_idx
|
|
|
|
|
self.cached_idx_map[slot_id] = row_id
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU :
|
|
|
|
|
self.freq_cnter[slot_id] = 0
|
|
|
|
|
self.inverted_cached_idx[row_id] = slot_offset
|
|
|
|
|
|
|
|
|
|
self._cuda_available_row_num -= 1
|
|
|
|
|