|
|
|
@ -5,7 +5,7 @@ from typing import List, Optional
|
|
|
|
|
from contexttimer import Timer |
|
|
|
|
from .copyer import LimitBuffIndexCopyer |
|
|
|
|
from enum import Enum |
|
|
|
|
|
|
|
|
|
import sys |
|
|
|
|
|
|
|
|
|
class EvictionStrategy(Enum): |
|
|
|
|
LFU = 1 |
|
|
|
@ -25,14 +25,14 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
cuda_row_num: int = 0, |
|
|
|
|
buffer_size: int = 50_000, |
|
|
|
|
pin_weight=False, |
|
|
|
|
evict_strategy=EvictionStrategy.DATASET) -> None: |
|
|
|
|
evict_strategy=EvictionStrategy.DATASET,) -> None: |
|
|
|
|
super(CachedParamMgr, self).__init__() |
|
|
|
|
self.buffer_size = buffer_size |
|
|
|
|
self.num_embeddings, self.embedding_dim = weight.shape |
|
|
|
|
self.cuda_row_num = cuda_row_num |
|
|
|
|
self._cuda_available_row_num = self.cuda_row_num |
|
|
|
|
self.pin_weight = pin_weight |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.elem_size_in_byte = weight.element_size() |
|
|
|
|
|
|
|
|
|
# weight configure |
|
|
|
@ -50,12 +50,22 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
|
# cpu_row_idx -> frequency, freq of the cpu rows. |
|
|
|
|
# evict the minimal freq value row in cuda cache. |
|
|
|
|
|
|
|
|
|
''' |
|
|
|
|
During cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we temporarily re-map that element to -1. |
|
|
|
|
Also, disabled cached_idx_map elements map to -1 by default. |
|
|
|
|
freq_cnter[-1], the last element, should ALWAYS be the MAX VALUE so that masked or disabled idxs are argsorted to the end, |
|
|
|
|
and are therefore never chosen for eviction. |
|
|
|
|
|
|
|
|
|
ZH: freq_cnter的最后一位设为了最大值, 不该被选为换出的cache idx都是-1, 指向这个最大值, 所以排序时在队尾, 不会被选中换出 |
|
|
|
|
''' |
|
|
|
|
self.register_buffer("freq_cnter", |
|
|
|
|
torch.empty(self.num_embeddings, device=torch.cuda.current_device(), |
|
|
|
|
torch.empty(self.num_embeddings + 1, device=torch.cuda.current_device(), |
|
|
|
|
dtype=torch.long).fill_(0), |
|
|
|
|
persistent=False) |
|
|
|
|
self.freq_cnter[-1] = sys.maxsize |
|
|
|
|
|
|
|
|
|
def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None: |
|
|
|
|
def _update_freq_cnter(self, cpu_row_idxs_original: torch.Tensor) -> None: |
|
|
|
|
"""_update_freq_cnter |
|
|
|
|
|
|
|
|
|
Update the frequency value w.r.t. the cpu_row_ids in self.freq_cnter. |
|
|
|
@ -64,7 +74,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight. |
|
|
|
|
""" |
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
|
self.freq_cnter[cpu_row_idxs] += 1 |
|
|
|
|
add_num = torch.bincount(cpu_row_idxs_original) |
|
|
|
|
self.freq_cnter[:add_num.shape[0]] += add_num |
|
|
|
|
|
|
|
|
|
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor: |
|
|
|
|
"""_find_evict_gpu_idxs |
|
|
|
@ -165,10 +176,13 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
warmup_ratio (float): the amount of chunks preloaded in cuda cache |
|
|
|
|
""" |
|
|
|
|
if ids_freq_mapping is not None: |
|
|
|
|
ids_freq_mapping = torch.tensor(ids_freq_mapping) |
|
|
|
|
tmp_idx = torch.argsort(ids_freq_mapping, descending=True) |
|
|
|
|
sorted_idx = torch.argsort(tmp_idx) |
|
|
|
|
self.idx_map.data.copy_(sorted_idx) |
|
|
|
|
|
|
|
|
|
# initialize freq_cnter if using LFU |
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
|
self.freq_cnter[:-1],_ = torch.sort(ids_freq_mapping) |
|
|
|
|
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks. |
|
|
|
|
# As cuda_cached_weight is very big. You may not have that much available memory! |
|
|
|
|
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda |
|
|
|
@ -249,8 +263,9 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
torch.Tensor: indices on the cuda_cached_weight. |
|
|
|
|
""" |
|
|
|
|
with record_function("(zhg) get unique indices"): |
|
|
|
|
cpu_row_idxs = torch.unique(self.idx_map.index_select(0, ids)) |
|
|
|
|
|
|
|
|
|
cpu_row_idxs_original = self.idx_map.index_select(0, ids) |
|
|
|
|
cpu_row_idxs = torch.unique(cpu_row_idxs_original) |
|
|
|
|
|
|
|
|
|
assert len(cpu_row_idxs) <= self.cuda_row_num, \ |
|
|
|
|
f"the input indices pull {len(cpu_row_idxs)} chunks, " \ |
|
|
|
|
f"which is larger than the presented {self.cuda_row_num}, " \ |
|
|
|
@ -272,10 +287,9 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
# new ids chunk_offset + offset_in_chunk |
|
|
|
|
with record_function("(zhg) embed idx -> cache chunk id"): |
|
|
|
|
gpu_row_idxs = self._id_to_cached_cuda_id(ids) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# update for LFU. |
|
|
|
|
self._update_freq_cnter(cpu_row_idxs) |
|
|
|
|
|
|
|
|
|
self._update_freq_cnter(cpu_row_idxs_original) |
|
|
|
|
return gpu_row_idxs |
|
|
|
|
|
|
|
|
|
def _reset_comm_stats(self): |
|
|
|
@ -298,26 +312,23 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
if evict_num > 0: |
|
|
|
|
with Timer() as timer: |
|
|
|
|
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) |
|
|
|
|
|
|
|
|
|
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) |
|
|
|
|
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.DATASET: |
|
|
|
|
# mask method. |
|
|
|
|
# set cached_idx_map[invalid_idxs] to -2. |
|
|
|
|
# so those idxs will be sorted to end, therefore not being chosen as victim |
|
|
|
|
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) |
|
|
|
|
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone() |
|
|
|
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -2) |
|
|
|
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) |
|
|
|
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) |
|
|
|
|
|
|
|
|
|
elif self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
|
# another mask method. |
|
|
|
|
# set freq_cnter[invalid_idxs] to max |
|
|
|
|
# so those idxs will be sorted to end, therefore not being chosen as victim |
|
|
|
|
backup_cnter = self.freq_cnter[invalid_idxs].clone() |
|
|
|
|
self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1) # or can we use a confident max value? |
|
|
|
|
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) |
|
|
|
|
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone() |
|
|
|
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -1) |
|
|
|
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) |
|
|
|
|
self.freq_cnter.index_copy_(0,invalid_idxs,backup_cnter) |
|
|
|
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) |
|
|
|
|
|
|
|
|
|
evict_info = self.cached_idx_map[evict_gpu_row_idxs] |
|
|
|
|
|
|
|
|
|