@@ -20,15 +20,15 @@ class CachedParamMgr(torch.nn.Module):
     CPU maintains the entire original weight.
     CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`.
     During training, GPU needs to transmit embedding rows between CPU and GPU.
     Args:
         weight (torch.Tensor): the weight of the Embedding layer.
         cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0.
         buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000.
-        pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase largely. Defaults to False.
-        evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options. `EvictionStrategy.LFU` uses the least frequently used cache. `EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
-        Default as EvictionStrategy.DATASET.
-        use_cpu_caching (bool, optional): use cpu to execute cache indexing. It is slower than use gpu.
+        pin_weight (bool, optional): use pinned memory to store the CPU weight. If set to `True`, CPU memory usage increases significantly. Defaults to False.
+        evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options.
+        `EvictionStrategy.LFU`: use a least-frequently-used cache.
+        `EvictionStrategy.DATASET`: use statistics collected from the target dataset. It usually leads to less CPU-GPU communication volume.
+        Defaults to EvictionStrategy.DATASET.
     """

     def __init__(
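
For orientation, a minimal usage sketch of the manager this docstring describes. The import path, table size and embedding dim below are illustrative assumptions, not taken from this patch:

    # hypothetical usage sketch -- the import path is assumed
    import torch
    from colossalai.nn.parallel.layers.cache_embedding import CachedParamMgr, EvictionStrategy

    weight = torch.randn(10_000, 128)            # full embedding table, kept in CPU memory
    mgr = CachedParamMgr(weight,
                         cuda_row_num=1_000,     # at most 1_000 rows are resident in CUDA memory
                         buffer_size=0,          # 0 disables the limited-size copy buffer
                         pin_weight=True,        # pinned CPU memory speeds up CPU<->GPU copies
                         evict_strategy=EvictionStrategy.LFU)
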
@@ -38,7 +38,6 @@ class CachedParamMgr(torch.nn.Module):
         buffer_size: int = 0,
         pin_weight: bool = False,
         evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
-        use_cpu_caching=False,
     ) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
@@ -48,13 +47,6 @@ class CachedParamMgr(torch.nn.Module):
         self.pin_weight = pin_weight
         self.elem_size_in_byte = weight.element_size()

-        self._cpu_caching = use_cpu_caching
-
-        if self._cpu_caching:
-            self._cache_dev = torch.device('cpu')
-        else:
-            self._cache_dev = torch.cuda.current_device()
-
         # weight configure
         self._init_weight(weight)
@@ -69,24 +61,16 @@ class CachedParamMgr(torch.nn.Module):
         if self._evict_strategy == EvictionStrategy.LFU:
             # cache_row_idx -> frequency, freq of the cache rows.
             # classic lfu cache. evict the minimal freq value row in cuda cache.
-            if self._cpu_caching:
-                self.freq_cnter = torch.empty(self.cuda_row_num, device=self._cache_dev,
-                                              dtype=torch.long).fill_(sys.maxsize)
-            else:
-                self.register_buffer("freq_cnter",
-                                     torch.empty(self.cuda_row_num, device=self._cache_dev,
-                                                 dtype=torch.long).fill_(sys.maxsize),
-                                     persistent=False)
+            self.register_buffer("freq_cnter",
+                                 torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
+                                             dtype=torch.long).fill_(sys.maxsize),
+                                 persistent=False)

     def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
         """_find_evict_gpu_idxs
         Find the gpu idxs to be evicted, according to their freq.
         Args:
             evict_num (int): how many rows have to be evicted
         Returns:
             torch.Tensor: a 1D tensor containing the gpu_row_idxs to be evicted.
         """
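
The LFU path above reduces to picking the cache slots with the smallest frequency counters. A self-contained sketch of that selection rule (freq_cnter and evict_num mirror the names in the patch; the topk call illustrates the idea and is not claimed to be the exact implementation):

    import torch

    # one access counter per CUDA cache slot (the patch fills it with sys.maxsize at init)
    freq_cnter = torch.tensor([5, 2, 9, 7, 1])
    evict_num = 2

    # LFU: the slots with the smallest counters are evicted first
    _, evict_gpu_row_idxs = torch.topk(freq_cnter, evict_num, largest=False)
    print(evict_gpu_row_idxs)    # tensor([4, 1]) -> the rows touched 1 and 2 times
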
@@ -117,32 +101,26 @@ class CachedParamMgr(torch.nn.Module):
         self.weight = weight.pin_memory() if self.pin_weight else weight
+        # map original id to new id with respect to frequency
+        # id -> cpu_row_idx
+        self.register_buffer(
+            "idx_map",
+            torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()),
+            persistent=False,
+        )

-        if self._cpu_caching:
-            self.idx_map = torch.arange(self.num_embeddings, dtype=torch.long, device=self._cache_dev)
-            self.cached_idx_map = torch.empty(self.cuda_row_num, device=self._cache_dev, dtype=torch.long).fill_(-1)
-            self.inverted_cached_idx = torch.zeros(self.num_embeddings, device=self._cache_dev,
-                                                   dtype=torch.long).fill_(-1)
-        else:
-            self.register_buffer(
-                "idx_map",
-                torch.arange(self.num_embeddings, dtype=torch.long, device=self._cache_dev),
-                persistent=False,
-            )
-
-            # cached_idx_map: gpu_row_idx -> cpu_row_idx
-            self.register_buffer("cached_idx_map",
-                                 torch.empty(self.cuda_row_num, device=self._cache_dev, dtype=torch.long).fill_(-1),
-                                 persistent=False)
-
-            # cpu_row_id -> gpu_row_idx.
-            # gpu_row_idx as -1 means cpu_row_id not in CUDA.
-            self.register_buffer("inverted_cached_idx",
-                                 torch.zeros(self.num_embeddings, device=self._cache_dev,
-                                             dtype=torch.long).fill_(-1),
-                                 persistent=False)
-
-        self.evict_backlist = torch.tensor([], device=self._cache_dev)
+        # cached_idx_map: gpu_row_idx -> cpu_row_idx
+        self.register_buffer("cached_idx_map",
+                             torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
+                                         dtype=torch.long).fill_(-1),
+                             persistent=False)
+
+        # cpu_row_id -> gpu_row_idx.
+        # gpu_row_idx as -1 means cpu_row_id not in CUDA.
+        self.register_buffer("inverted_cached_idx",
+                             torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
+                                         dtype=torch.long).fill_(-1),
+                             persistent=False)
+
+        self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())

         # index copy buffer size should be less than 10% of cuda weight.
         if self.buffer_size > 0:
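
The buffers registered above form a small indirection chain. A toy illustration of how they relate (the values are made up for a 6-row table with a 3-slot cache):

    import torch

    # idx_map:             original id -> cpu_row_idx (identity until reorder() permutes it)
    # cached_idx_map:      gpu cache slot -> cpu_row_idx held in that slot, -1 if the slot is empty
    # inverted_cached_idx: cpu_row_idx -> gpu cache slot, -1 if that row is not resident in CUDA
    idx_map = torch.arange(6)
    cached_idx_map = torch.tensor([4, 0, -1])                    # slot 0 holds cpu row 4, slot 1 holds cpu row 0
    inverted_cached_idx = torch.tensor([1, -1, -1, -1, 0, -1])

    cpu_row = idx_map[3]                                         # id 3 -> cpu row 3
    print(inverted_cached_idx[cpu_row].item())                   # -1: row 3 is not cached on the GPU
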
@@ -157,10 +135,8 @@ class CachedParamMgr(torch.nn.Module):
     def cpu_weight_data(self, row_idx: int) -> torch.Tensor:
         """
         access a row of CPU weight.
-
         Args:
             row_idx (int): the idx of the row
-
         Returns:
             torch.Tensor: a piece of memory in CPU weight corresponding to the row id's payload. The tensor is 1-D.
         """
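
What this docstring describes is simply a flat, storage-sharing view of one row of the 2-D CPU table; a small stand-alone illustration (not the method's actual code):

    import torch

    num_embeddings, embedding_dim = 4, 3
    weight = torch.arange(12, dtype=torch.float32).view(num_embeddings, embedding_dim)

    row_idx = 2
    row = weight.view(num_embeddings, -1)[row_idx]   # 1-D view of shape (embedding_dim,), shares storage
    print(row)                                       # tensor([6., 7., 8.])
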
@@ -178,14 +154,12 @@ class CachedParamMgr(torch.nn.Module):
         """reorder
         reorder the weight according to ids' frequency in the dataset before training.
         Execute only once before training, also known as the warmup phase.
-
         Note:
             If you would like to use the DATASET eviction strategy, you must call this function.
-
         Note:
             If you use LFU as the eviction strategy, you can skip this function. If you still call it, it will initialize
             the frequency counters of the LFU cache using the dataset statistics.

         Args:
             ids_freq_mapping (List[int]): a list whose offset is the id number and whose value is the frequency. If None, the cpu weight is not reordered.
             warmup_ratio (float): the ratio of rows preloaded into the CUDA cache.
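
A sketch of the warmup call these notes refer to: ids_freq_mapping is just a per-id frequency table, which can be built from the training ids with torch.bincount (mgr is the manager from the earlier sketch; whether a plain Python list or a tensor is expected here is an assumption):

    import torch

    num_embeddings = 10_000
    train_ids = torch.randint(0, num_embeddings, (1_000_000,))
    ids_freq_mapping = torch.bincount(train_ids, minlength=num_embeddings)   # offset = id, value = frequency

    # reorder the CPU weight by frequency and preload roughly warmup_ratio * cuda_row_num rows before training
    mgr.reorder(ids_freq_mapping, warmup_ratio=0.7)
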
@@ -209,24 +183,24 @@ class CachedParamMgr(torch.nn.Module):
                 # extract rows from cpu weight
                 if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None:
                     freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
-                    preload_cuda_row_idxs = torch.arange(preload_row_num).to(self._cache_dev)
+                    preload_cuda_row_idxs = torch.arange(preload_row_num).cuda()
                 else:
-                    preload_cpu_ids = torch.arange(preload_row_num, device=self.weight.device)
-                    preload_cuda_row_idxs = preload_cpu_ids.to(self._cache_dev)
+                    preload_cpu_ids = torch.arange(preload_row_num)
+                    preload_cuda_row_idxs = preload_cpu_ids.cuda()

                 if self.buffer_size > 0:
                     self.limit_buff_index_copyer.index_copy(0,
                                                             src_index=preload_cpu_ids,
-                                                            tgt_index=preload_cuda_row_idxs.cuda(),
+                                                            tgt_index=preload_cuda_row_idxs,
                                                             src=self.weight.view(self.num_embeddings, -1),
                                                             tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
                 else:
                     preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda()
-                    self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs.cuda(),
+                    self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
                                                                                     preload_rows)

                 # update auxiliary info
-                self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.to(self._cache_dev)
+                self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda()
                 self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs
                 self._cuda_available_row_num -= preload_row_num
@@ -235,7 +209,7 @@ class CachedParamMgr(torch.nn.Module):
                     if ids_freq_mapping is None:
                         self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
                     else:
-                        self.freq_cnter[preload_cuda_row_idxs] = freq_value.to(self._cache_dev)
+                        self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda()

             print(f'Cache warmup finished cost {timer.elapsed} sec.')
@@ -245,7 +219,7 @@ class CachedParamMgr(torch.nn.Module):
         """
         slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
         row_ids = self.cached_idx_map[slots]
-        rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots.cuda()).cpu()
+        rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
         self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
         self.cached_idx_map.index_fill_(0, slots, -1)
         self.inverted_cached_idx.index_fill_(0, row_ids, -1)
@@ -272,10 +246,8 @@ class CachedParamMgr(torch.nn.Module):
         """
         convert ids to indices in self.cuda_cached_weight.
         Implemented with parallel operations on GPU.
-
         Args:
             ids (torch.Tensor): ids from the dataset
-
         Returns:
             torch.Tensor: contains indices in self.cuda_cached_weight
         """
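
Conceptually the conversion chains two of the lookup tables set up in _init_weight; a hedged sketch with toy values (the real method may fuse these steps differently, and it assumes prepare_ids has already made every requested id resident):

    import torch

    idx_map = torch.arange(6)                                    # id -> cpu_row_idx
    inverted_cached_idx = torch.tensor([1, -1, -1, -1, 0, -1])   # cpu_row_idx -> gpu cache slot

    ids = torch.tensor([0, 4, 4])
    gpu_slots = inverted_cached_idx.index_select(0, idx_map.index_select(0, ids))
    print(gpu_slots)    # tensor([1, 0, 0]) -> offsets into cuda_cached_weight
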
@@ -287,14 +259,12 @@ class CachedParamMgr(torch.nn.Module):
     def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
         """
         move the cpu embedding rows w.r.t. ids into CUDA memory
         Args:
             ids (torch.Tensor): the ids to be computed
         Returns:
             torch.Tensor: indices on the cuda_cached_weight.
         """
-        with record_function(f"(pre-id) get unique indices. cache ratio {self.cuda_row_num / self.num_embeddings}"):
-            ids = ids.to(self._cache_dev)
+        with record_function("(zhg) get unique indices"):
             cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)

         assert len(cpu_row_idxs) <= self.cuda_row_num, \
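
How a caller typically drives this method, as a hedged sketch: the wrapping embedding module is assumed to translate raw ids into cache slots with prepare_ids and then index the CUDA-resident cache instead of the full CPU table (mgr comes from the first sketch; the F.embedding lookup is an assumption about the caller, not something this patch shows):

    import torch
    import torch.nn.functional as F

    ids = torch.randint(0, 10_000, (4096,), device='cuda')

    cache_slots = mgr.prepare_ids(ids)                                # requested rows are now resident in CUDA
    out = F.embedding(cache_slots, mgr.cuda_cached_weight.view(mgr.cuda_row_num, -1))
    print(out.shape)                                                  # (4096, embedding_dim)
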
@@ -303,29 +273,26 @@ class CachedParamMgr(torch.nn.Module):
             f"Please increase cuda_row_num or decrease the training batch size."
         self.evict_backlist = cpu_row_idxs

-        with record_function("(pre-id) get cpu row idxs"):
-            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs,
-                                                        self.cached_idx_map,
-                                                        assume_unique=True,
-                                                        invert=True)]
+        with record_function("(zhg) get cpu row idxs"):
+            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]

         self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
         self.num_miss_history.append(len(comm_cpu_row_idxs))
         self.num_write_back_history.append(0)

         # make sure the cuda rows will not be evicted!
-        with record_function("(pre-id) cache update"):
+        with record_function("(zhg) cache update"):
             self._prepare_rows_on_cuda(comm_cpu_row_idxs)

-        self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
-
-        with record_function("(pre-id) embed cpu rows idx -> cache gpu row idxs"):
+        self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
+
+        with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)

         # update for LFU.
         if self._evict_strategy == EvictionStrategy.LFU:
-            with record_function("(pre-id) lfu cnter updates"):
-                unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
-                self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
+            unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
+            self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)

         return gpu_row_idxs
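
The three history lists appended to above make it easy to report a cache hit rate after a training pass; a small bookkeeping sketch (mgr as before):

    # hits and misses are counted over the unique ids of each prepare_ids() call
    hits = sum(mgr.num_hits_history)
    misses = sum(mgr.num_miss_history)
    hit_rate = hits / max(hits + misses, 1)
    print(f"unique-id cache hit rate: {hit_rate:.2%}")
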
@@ -341,14 +308,13 @@ class CachedParamMgr(torch.nn.Module):
     @torch.no_grad()
     def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
         """prepare rows in cpu_row_idxs on CUDA memory
-
         Args:
             cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
         """
         evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
         if evict_num > 0:
             with Timer() as timer:
-                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist, assume_unique=True)
+                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
                 invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                 if self._evict_strategy == EvictionStrategy.DATASET:
                     # mask method.
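
The eviction count computed at the top of this method is simply "rows this batch needs" minus "slots still free"; a toy check of that arithmetic with made-up numbers:

    # evict_num > 0 means the cache cannot admit the incoming rows without first evicting old ones
    cuda_available_row_num = 40     # free slots left in the CUDA cache
    incoming_rows = 128             # cpu_row_idxs.numel() for this batch

    evict_num = incoming_rows - cuda_available_row_num
    print(evict_num)                # 88 rows must be written back to CPU before admission
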
@@ -375,8 +341,7 @@ class CachedParamMgr(torch.nn.Module):
                                                            tgt=self.weight.view(self.num_embeddings, -1))
                 else:
                     # allocate tmp memory on CPU and copy rows on CUDA to CPU.
-                    rows = self.cuda_cached_weight.view(self.cuda_row_num,
-                                                        -1).index_select(0, evict_gpu_row_idxs.cuda()).cpu()
+                    rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu()
                     self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)

                 self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
@@ -395,12 +360,12 @@ class CachedParamMgr(torch.nn.Module):
         if self.buffer_size > 0:
             self.limit_buff_index_copyer.index_copy(0,
                                                     src_index=cpu_row_idxs.cpu(),
-                                                    tgt_index=slots.cuda(),
+                                                    tgt_index=slots,
                                                     src=self.weight.view(self.num_embeddings, -1),
                                                     tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
         else:
             rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
-            self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots.cuda(), rows)
+            self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)
         slot_offsets = slots
         self.cached_idx_map[slots] = cpu_row_idxs
         self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
@@ -421,7 +386,6 @@ class CachedParamMgr(torch.nn.Module):
     def _evict(self) -> int:
         """
         deprecated
-
         evict one row from cuda to cpu.
         Returns:
             (int) : the slot id to be evicted.
@@ -463,9 +427,7 @@ class CachedParamMgr(torch.nn.Module):
     def _admit(self, row_id: int):
         """
         deprecated
-
         move in row_id to CUDA
-
         Args:
             row_id (int): the id of row to be moved in
         """
@@ -491,4 +453,4 @@ class CachedParamMgr(torch.nn.Module):
             self._cuda_available_row_num -= 1

             self._cpu_to_cuda_numel += self.embedding_dim
-            self._cpu_to_cuda_elpase += timer.elapsed
+            self._cpu_to_cuda_elpase += timer.elapsed