|
|
|
@ -56,7 +56,6 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
self.num_hits_history = []
|
|
|
|
|
self.num_miss_history = []
|
|
|
|
|
self.num_write_back_history = []
|
|
|
|
|
self.input_id_percent_in_load_chunk = []
|
|
|
|
|
self._reset_comm_stats()
|
|
|
|
|
|
|
|
|
|
self._evict_strategy = evict_strategy
|
|
|
|
@ -156,23 +155,23 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
# self.cuda_cached_weight = self.weight
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
def cpu_weight_data(self, chunk_id: int) -> torch.Tensor:
|
|
|
|
|
def cpu_weight_data(self, row_idx: int) -> torch.Tensor:
|
|
|
|
|
"""
|
|
|
|
|
access a chunk of CPU weight.
|
|
|
|
|
access a row of CPU weight.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
chunk_id (int): chunk id
|
|
|
|
|
row_idx (int): the idx of rows
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D.
|
|
|
|
|
torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
return self.weight.data.view(-1).narrow(0,
|
|
|
|
|
int(chunk_id) * self.embedding_dim,
|
|
|
|
|
int(row_idx) * self.embedding_dim,
|
|
|
|
|
self.embedding_dim).view(1, self.embedding_dim)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def cuda_available_chunk_num(self):
|
|
|
|
|
def cuda_available_row_num(self):
|
|
|
|
|
return self._cuda_available_row_num
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
@ -202,7 +201,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
|
|
|
|
|
if preload_row_num > 0:
|
|
|
|
|
with Timer() as timer:
|
|
|
|
|
# extract chunks from cpu weight
|
|
|
|
|
# extract rows from cpu weight
|
|
|
|
|
preload_row_ids = torch.arange(preload_row_num)
|
|
|
|
|
preload_slot_ids = preload_row_ids.cuda()
|
|
|
|
|
|
|
|
|
@ -213,8 +212,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
src=self.weight.view(self.num_embeddings, -1),
|
|
|
|
|
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
|
|
|
|
else:
|
|
|
|
|
preload_chunks = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
|
|
|
|
|
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks)
|
|
|
|
|
preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
|
|
|
|
|
self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows)
|
|
|
|
|
|
|
|
|
|
# update auxiliary info
|
|
|
|
|
slot_offsets = preload_slot_ids
|
|
|
|
@ -224,15 +223,15 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
print(f'Cache warmup finished cost {timer.elapsed} sec.')
|
|
|
|
|
|
|
|
|
|
def flush(self):
|
|
|
|
|
"""flush all CUDA chunks to CPU.
|
|
|
|
|
"""flush all CUDA rows to CPU.
|
|
|
|
|
The function is usually called after training finished.
|
|
|
|
|
"""
|
|
|
|
|
slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
|
|
|
|
|
chunk_ids = self.cached_idx_map[slots]
|
|
|
|
|
chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
|
|
|
|
|
self.weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks)
|
|
|
|
|
row_ids = self.cached_idx_map[slots]
|
|
|
|
|
rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
|
|
|
|
|
self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
|
|
|
|
|
self.cached_idx_map.index_fill_(0, slots, -1)
|
|
|
|
|
self.inverted_cached_idx.index_fill_(0, chunk_ids, -1)
|
|
|
|
|
self.inverted_cached_idx.index_fill_(0, row_ids, -1)
|
|
|
|
|
self._cuda_available_row_num += slots.numel()
|
|
|
|
|
|
|
|
|
|
assert self._cuda_available_row_num == self.cuda_row_num
|
|
|
|
@ -280,25 +279,25 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
cpu_row_idxs = torch.unique(cpu_row_idxs_original)
|
|
|
|
|
|
|
|
|
|
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
|
|
|
|
f"the input indices pull {len(cpu_row_idxs)} chunks, " \
|
|
|
|
|
f"which is larger than the presented {self.cuda_row_num}, " \
|
|
|
|
|
f"please increase cuda_row_num shrink batch size"
|
|
|
|
|
f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
|
|
|
|
|
f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
|
|
|
|
|
f"Please increase cuda_row_num or decrease the training batch size."
|
|
|
|
|
self.evict_backlist = cpu_row_idxs
|
|
|
|
|
|
|
|
|
|
with record_function("(zhg) get cpu chunk indices"):
|
|
|
|
|
with record_function("(zhg) get cpu row idxs"):
|
|
|
|
|
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
|
|
|
|
|
|
|
|
|
|
self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
|
|
|
|
|
self.num_miss_history.append(len(comm_cpu_row_idxs))
|
|
|
|
|
self.num_write_back_history.append(0)
|
|
|
|
|
|
|
|
|
|
# move sure the cuda chunk will not be evicted!
|
|
|
|
|
# move sure the cuda rows will not be evicted!
|
|
|
|
|
with record_function("(zhg) cache update"):
|
|
|
|
|
self._prepare_rows_on_cuda(comm_cpu_row_idxs)
|
|
|
|
|
|
|
|
|
|
self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
|
|
|
|
|
# new ids chunk_offset + offset_in_chunk
|
|
|
|
|
with record_function("(zhg) embed idx -> cache chunk id"):
|
|
|
|
|
|
|
|
|
|
with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
|
|
|
|
|
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
|
|
|
|
|
|
|
|
|
|
# update for LFU.
|
|
|
|
@ -311,17 +310,17 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
self._cuda_to_cpu_elapse = 0
|
|
|
|
|
self._cuda_to_cpu_numel = 0
|
|
|
|
|
|
|
|
|
|
def _chunk_in_cuda(self, chunk_id: int) -> bool:
|
|
|
|
|
return self.inverted_cached_idx[chunk_id] != -1
|
|
|
|
|
def _row_in_cuda(self, row_id: int) -> bool:
|
|
|
|
|
return self.inverted_cached_idx[row_id] != -1
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
|
|
|
|
|
"""prepare rows in cpu_row_idxs on CUDA memory
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
|
|
|
|
|
cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
|
|
|
|
|
"""
|
|
|
|
|
evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num
|
|
|
|
|
evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
|
|
|
|
|
if evict_num > 0:
|
|
|
|
|
with Timer() as timer:
|
|
|
|
|
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
|
|
|
@ -396,7 +395,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
deprecated
|
|
|
|
|
|
|
|
|
|
evict one chunk from cuda to cpu.
|
|
|
|
|
evict one row from cuda to cpu.
|
|
|
|
|
Returns:
|
|
|
|
|
(int) : the slot id be evicted.
|
|
|
|
|
"""
|
|
|
|
|