[FAW] remove code related to chunk (#1501)

pull/1504/head
Jiarui Fang 2 years ago committed by GitHub
parent d5085bb317
commit ba61109b6c

@@ -56,7 +56,6 @@ class CachedParamMgr(torch.nn.Module):
         self.num_hits_history = []
         self.num_miss_history = []
         self.num_write_back_history = []
-        self.input_id_percent_in_load_chunk = []
         self._reset_comm_stats()

         self._evict_strategy = evict_strategy
@@ -156,23 +155,23 @@ class CachedParamMgr(torch.nn.Module):
             # self.cuda_cached_weight = self.weight
             raise NotImplementedError()

-    def cpu_weight_data(self, chunk_id: int) -> torch.Tensor:
+    def cpu_weight_data(self, row_idx: int) -> torch.Tensor:
         """
-        access a chunk of CPU weight.
+        access a row of CPU weight.

         Args:
-            chunk_id (int): chunk id
+            row_idx (int): the idx of rows

         Returns:
-            torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D.
+            torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D.
         """
         return self.weight.data.view(-1).narrow(0,
-                                                 int(chunk_id) * self.embedding_dim,
+                                                 int(row_idx) * self.embedding_dim,
                                                  self.embedding_dim).view(1, self.embedding_dim)

     @property
-    def cuda_available_chunk_num(self):
+    def cuda_available_row_num(self):
         return self._cuda_available_row_num

     @torch.no_grad()
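For readers following the row-slicing arithmetic in `cpu_weight_data` above, here is a minimal standalone sketch of why narrowing the flattened weight at `row_idx * embedding_dim` yields one row's payload. It is a toy, not the CachedParamMgr class; the sizes and the `cpu_row_view` helper name are invented for illustration.

```python
import torch

# Toy sizes; the real manager keeps a (num_embeddings, embedding_dim) table on CPU.
num_embeddings, embedding_dim = 10, 4
weight = torch.arange(num_embeddings * embedding_dim, dtype=torch.float32).view(num_embeddings, embedding_dim)

def cpu_row_view(weight: torch.Tensor, row_idx: int, embedding_dim: int) -> torch.Tensor:
    # narrow() returns a view, so no data is copied; reshape the slice to (1, embedding_dim)
    return weight.data.view(-1).narrow(0, row_idx * embedding_dim, embedding_dim).view(1, embedding_dim)

row = cpu_row_view(weight, 3, embedding_dim)
assert torch.equal(row, weight[3:4])  # same values as plain row indexing
```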
@@ -202,7 +201,7 @@ class CachedParamMgr(torch.nn.Module):
         preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
         if preload_row_num > 0:
             with Timer() as timer:
-                # extract chunks from cpu weight
+                # extract rows from cpu weight
                 preload_row_ids = torch.arange(preload_row_num)
                 preload_slot_ids = preload_row_ids.cuda()
@@ -213,8 +212,8 @@ class CachedParamMgr(torch.nn.Module):
                                            src=self.weight.view(self.num_embeddings, -1),
                                            tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
                 else:
-                    preload_chunks = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
-                    self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks)
+                    preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
+                    self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows)

                 # update auxiliary info
                 slot_offsets = preload_slot_ids
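The warmup path above preloads the first rows of the CPU table into the CUDA cache with `index_select` plus `index_copy_`. Below is a minimal, device-agnostic sketch of that copy pattern with assumed toy sizes; it skips the buffered-copy branch shown in the if-arm and is not the class's actual warmup method.

```python
import math
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
num_embeddings, embedding_dim, cuda_row_num = 100, 8, 16   # toy sizes
weight = torch.randn(num_embeddings, embedding_dim)        # CPU-resident full table
cuda_cached_weight = torch.zeros(cuda_row_num, embedding_dim, device=device)  # cache buffer

warmup_ratio = 0.7
preload_row_num = min(math.ceil(cuda_row_num * warmup_ratio), num_embeddings)
preload_row_ids = torch.arange(preload_row_num)
preload_slot_ids = preload_row_ids.to(device)              # during warmup, slot i simply caches row i

# gather the rows on CPU, move them in one transfer, scatter into the cache slots
preload_rows = weight.index_select(0, preload_row_ids).to(device)
cuda_cached_weight.index_copy_(0, preload_slot_ids, preload_rows)
```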
@@ -224,15 +223,15 @@ class CachedParamMgr(torch.nn.Module):
             print(f'Cache warmup finished cost {timer.elapsed} sec.')

     def flush(self):
-        """flush all CUDA chunks to CPU.
+        """flush all CUDA rows to CPU.

         The function is usually called after training finished.
         """
         slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
-        chunk_ids = self.cached_idx_map[slots]
-        chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
-        self.weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks)
+        row_ids = self.cached_idx_map[slots]
+        rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
+        self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
         self.cached_idx_map.index_fill_(0, slots, -1)
-        self.inverted_cached_idx.index_fill_(0, chunk_ids, -1)
+        self.inverted_cached_idx.index_fill_(0, row_ids, -1)
         self._cuda_available_row_num += slots.numel()

         assert self._cuda_available_row_num == self.cuda_row_num
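A self-contained sketch of the flush bookkeeping follows, assuming the two index maps mean what their names suggest (`cached_idx_map[slot]` = cached row id or -1, `inverted_cached_idx[row]` = slot or -1). The toy tensors are invented and everything runs on CPU for simplicity.

```python
import torch

num_embeddings, embedding_dim, cuda_row_num = 6, 2, 3
weight = torch.zeros(num_embeddings, embedding_dim)             # CPU table to write back into
cuda_cached_weight = torch.ones(cuda_row_num, embedding_dim)    # stand-in for trained cached rows
cached_idx_map = torch.tensor([4, -1, 2])                       # slot -> cached row id (-1 = empty)
inverted_cached_idx = torch.full((num_embeddings,), -1)         # row id -> slot (-1 = not cached)
inverted_cached_idx[4], inverted_cached_idx[2] = 0, 2

slots = torch.nonzero(cached_idx_map > -1).squeeze(1)           # occupied slots
row_ids = cached_idx_map[slots]                                 # rows those slots hold
rows = cuda_cached_weight.index_select(0, slots).cpu()
weight.index_copy_(0, row_ids.cpu(), rows)                      # copy cached rows back to the CPU table

cached_idx_map.index_fill_(0, slots, -1)                        # empty every slot
inverted_cached_idx.index_fill_(0, row_ids, -1)                 # forget where the rows lived
```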
@@ -280,25 +279,25 @@ class CachedParamMgr(torch.nn.Module):
         cpu_row_idxs = torch.unique(cpu_row_idxs_original)

         assert len(cpu_row_idxs) <= self.cuda_row_num, \
-            f"the input indices pull {len(cpu_row_idxs)} chunks, " \
-            f"which is larger than the presented {self.cuda_row_num}, " \
-            f"please increase cuda_row_num shrink batch size"
+            f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
+            f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
+            f"Please increase cuda_row_num or decrease the training batch size."
         self.evict_backlist = cpu_row_idxs

-        with record_function("(zhg) get cpu chunk indices"):
+        with record_function("(zhg) get cpu row idxs"):
             comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]

         self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
         self.num_miss_history.append(len(comm_cpu_row_idxs))
         self.num_write_back_history.append(0)

-        # move sure the cuda chunk will not be evicted!
+        # move sure the cuda rows will not be evicted!
         with record_function("(zhg) cache update"):
             self._prepare_rows_on_cuda(comm_cpu_row_idxs)

         self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)

         # new ids chunk_offset + offset_in_chunk
-        with record_function("(zhg) embed idx -> cache chunk id"):
+        with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)

         # update for LFU.
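The hit/miss accounting above relies on `torch.unique` followed by `torch.isin` against `cached_idx_map`. A small sketch with made-up ids, outside the class:

```python
import torch

cached_idx_map = torch.tensor([7, 3, -1, 12])   # rows currently resident in the cache slots
ids = torch.tensor([3, 3, 12, 5, 9, 3])         # incoming embedding ids for one batch (with repeats)

cpu_row_idxs = torch.unique(ids)                # deduplicate before deciding what to move
# rows not present in the cache are misses and must be fetched from CPU
comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, cached_idx_map, invert=True)]

num_hits = len(cpu_row_idxs) - len(comm_cpu_row_idxs)   # 2: rows 3 and 12 are already cached
num_misses = len(comm_cpu_row_idxs)                     # 2: rows 5 and 9 must be brought in
```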
@@ -311,17 +310,17 @@ class CachedParamMgr(torch.nn.Module):
         self._cuda_to_cpu_elapse = 0
         self._cuda_to_cpu_numel = 0

-    def _chunk_in_cuda(self, chunk_id: int) -> bool:
-        return self.inverted_cached_idx[chunk_id] != -1
+    def _row_in_cuda(self, row_id: int) -> bool:
+        return self.inverted_cached_idx[row_id] != -1

     @torch.no_grad()
     def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
         """prepare rows in cpu_row_idxs on CUDA memory

         Args:
-            cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
+            cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
         """
-        evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num
+        evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
         if evict_num > 0:
             with Timer() as timer:
                 mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
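The eviction setup above computes how many resident rows must leave the cache and masks out rows in the eviction backlist so the current batch's rows are never evicted. A toy sketch of those two steps, with invented values:

```python
import torch

cuda_available_row_num = 1                       # free slots left in the cache
cpu_row_idxs = torch.tensor([5, 9, 21])          # deduplicated rows that must be on CUDA now
evict_num = cpu_row_idxs.numel() - cuda_available_row_num   # 3 - 1 = 2 rows must be evicted

cached_idx_map = torch.tensor([7, 3, 9, 12])     # rows currently held by the cache slots
evict_backlist = cpu_row_idxs
# True where a cached row is also needed by the current batch, so it must not be evicted
mask_cpu_row_idx = torch.isin(cached_idx_map, evict_backlist)   # [False, False, True, False]
```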
@@ -396,7 +395,7 @@ class CachedParamMgr(torch.nn.Module):
         """
         deprecated

-        evict one chunk from cuda to cpu.
+        evict one row from cuda to cpu.
         Returns:
         (int) : the slot id be evicted.
         """

@@ -119,8 +119,4 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
         if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
             return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
                 self.cache_weight_mgr._cuda_to_cpu_elapse
-        return 0
-
-    @property
-    def input_id_percent_in_load_chunk(self):
-        return 0 # np.mean(self.cache_weight_mgr.input_id_percent_in_load_chunk) * 100
+        return 0
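The property kept in this hunk reports CUDA-to-CPU write-back bandwidth: bytes moved divided by elapsed seconds, scaled by 1e6, i.e. MB/s. A worked example with made-up counters standing in for the manager's statistics:

```python
cuda_to_cpu_numel = 2_000_000    # elements written back from CUDA to CPU so far
elem_size_in_byte = 4            # float32
cuda_to_cpu_elapse = 0.5         # seconds spent on those copies

bandwidth = 0
if cuda_to_cpu_numel > 0:
    bandwidth = cuda_to_cpu_numel * elem_size_in_byte / 1e6 / cuda_to_cpu_elapse
print(bandwidth)                 # 16.0 MB/s (8 MB moved in 0.5 s)
```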