@@ -56,7 +56,6 @@ class CachedParamMgr(torch.nn.Module):
         self.num_hits_history = []
         self.num_miss_history = []
         self.num_write_back_history = []
-        self.input_id_percent_in_load_chunk = []
         self._reset_comm_stats()
 
         self._evict_strategy = evict_strategy
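
The per-step counters kept here (num_hits_history, num_miss_history, num_write_back_history) are enough to derive an overall cache hit rate after a run. A minimal sketch, assuming mgr is a CachedParamMgr whose history lists have been populated; the helper name is hypothetical:

    # Hypothetical helper: summarize cache behaviour from the per-step
    # history lists kept by CachedParamMgr.
    def cache_hit_rate(mgr) -> float:
        hits = sum(mgr.num_hits_history)      # rows served from the CUDA cache
        misses = sum(mgr.num_miss_history)    # rows fetched from the CPU weight
        total = hits + misses
        return hits / total if total > 0 else 0.0
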
@@ -156,23 +155,23 @@ class CachedParamMgr(torch.nn.Module):
             # self.cuda_cached_weight = self.weight
             raise NotImplementedError()
 
-    def cpu_weight_data(self, chunk_id: int) -> torch.Tensor:
+    def cpu_weight_data(self, row_idx: int) -> torch.Tensor:
         """
-        access a chunk of CPU weight.
+        access a row of CPU weight.
 
         Args:
-            chunk_id (int): chunk id
+            row_idx (int): the idx of rows
 
         Returns:
-            torch.Tensor: a piece of memory in CPU weight corresponding to chunk id's payload. The tensor is 1-D.
+            torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D.
         """
 
         return self.weight.data.view(-1).narrow(0,
-                                                 int(chunk_id) * self.embedding_dim,
+                                                 int(row_idx) * self.embedding_dim,
                                                  self.embedding_dim).view(1, self.embedding_dim)
 
     @property
-    def cuda_available_chunk_num(self):
+    def cuda_available_row_num(self):
         return self._cuda_available_row_num
 
     @torch.no_grad()
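
For reference, the narrow-based access in cpu_weight_data is plain row indexing on the flattened CPU weight: it takes embedding_dim elements starting at offset row_idx * embedding_dim and views them as a (1, embedding_dim) slice that shares storage with the weight. A standalone sketch with toy sizes (not the class itself) that checks the equivalence:

    import torch

    num_embeddings, embedding_dim = 8, 4
    weight = torch.arange(num_embeddings * embedding_dim, dtype=torch.float32).view(num_embeddings, embedding_dim)

    row_idx = 3
    # Same arithmetic as cpu_weight_data: embedding_dim elements starting at
    # row_idx * embedding_dim, taken from the flattened weight.
    row = weight.data.view(-1).narrow(0, row_idx * embedding_dim, embedding_dim).view(1, embedding_dim)

    assert torch.equal(row, weight[row_idx].view(1, embedding_dim))
    # It is a view, not a copy: writing through `row` mutates the CPU weight.
    row.fill_(0.0)
    assert torch.equal(weight[row_idx], torch.zeros(embedding_dim))
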
@@ -202,7 +201,7 @@ class CachedParamMgr(torch.nn.Module):
         preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
         if preload_row_num > 0:
             with Timer() as timer:
-                # extract chunks from cpu weight
+                # extract rows from cpu weight
                 preload_row_ids = torch.arange(preload_row_num)
                 preload_slot_ids = preload_row_ids.cuda()
 
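
The number of rows preloaded during warmup is ceil(cuda_row_num * warmup_ratio), capped at num_embeddings, so a ratio of 1.0 fills the whole cache whenever the embedding table is at least as large as the cache. A small numeric illustration with made-up sizes:

    import numpy as np

    cuda_row_num, num_embeddings = 1000, 10_000
    for warmup_ratio in (0.0, 0.25, 0.5, 1.0):
        preload_row_num = min(int(np.ceil(cuda_row_num * warmup_ratio)), num_embeddings)
        print(warmup_ratio, preload_row_num)    # 0.0 -> 0, 0.25 -> 250, 0.5 -> 500, 1.0 -> 1000
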
@@ -213,8 +212,8 @@ class CachedParamMgr(torch.nn.Module):
                                                             src=self.weight.view(self.num_embeddings, -1),
                                                             tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
                 else:
-                    preload_chunks = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
-                    self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_chunks)
+                    preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_row_ids).cuda()
+                    self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_slot_ids, preload_rows)
 
                 # update auxiliary info
                 slot_offsets = preload_slot_ids
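
When no copy buffer is used, the else branch preloads by gathering the first preload_row_num CPU rows with index_select and scattering them into the first cache slots with index_copy_. A CPU-only toy sketch of the same pattern (the .cuda() transfer is dropped so it runs anywhere):

    import torch

    num_embeddings, cuda_row_num, dim = 10, 4, 3
    weight = torch.randn(num_embeddings, dim)             # stands in for the CPU weight
    cuda_cached_weight = torch.zeros(cuda_row_num, dim)   # stands in for the CUDA cache

    preload_row_num = 2
    preload_row_ids = torch.arange(preload_row_num)
    preload_slot_ids = preload_row_ids.clone()            # during warmup, row i lands in slot i

    preload_rows = weight.index_select(0, preload_row_ids)
    cuda_cached_weight.index_copy_(0, preload_slot_ids, preload_rows)

    assert torch.equal(cuda_cached_weight[:preload_row_num], weight[:preload_row_num])
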
@@ -224,15 +223,15 @@ class CachedParamMgr(torch.nn.Module):
             print(f'Cache warmup finished cost {timer.elapsed} sec.')
 
     def flush(self):
-        """flush all CUDA chunks to CPU.
+        """flush all CUDA rows to CPU.
         The function is usually called after training finished.
         """
         slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
-        chunk_ids = self.cached_idx_map[slots]
-        chunks = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
-        self.weight.view(self.num_embeddings, -1).index_copy_(0, chunk_ids.cpu(), chunks)
+        row_ids = self.cached_idx_map[slots]
+        rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
+        self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
         self.cached_idx_map.index_fill_(0, slots, -1)
-        self.inverted_cached_idx.index_fill_(0, chunk_ids, -1)
+        self.inverted_cached_idx.index_fill_(0, row_ids, -1)
         self._cuda_available_row_num += slots.numel()
 
         assert self._cuda_available_row_num == self.cuda_row_num
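
flush() walks the occupied slots (cached_idx_map entries other than -1), writes each cached row back to the matching row of the CPU weight, and resets both index maps so the cache reads as empty. A toy, CPU-only sketch of the same bookkeeping with made-up maps:

    import torch

    num_embeddings, dim = 6, 2
    weight = torch.zeros(num_embeddings, dim)                 # CPU weight to write back into
    cuda_cached_weight = torch.ones(3, dim)                   # stands in for the CUDA cache

    cached_idx_map = torch.tensor([4, -1, 2])                 # slot -> row id, -1 means free
    inverted_cached_idx = torch.full((num_embeddings,), -1)   # row id -> slot
    inverted_cached_idx[4], inverted_cached_idx[2] = 0, 2

    slots = torch.nonzero(cached_idx_map > -1).squeeze(1)     # occupied slots: 0 and 2
    row_ids = cached_idx_map[slots]                           # rows 4 and 2
    rows = cuda_cached_weight.index_select(0, slots)          # .cpu() in the real code
    weight.index_copy_(0, row_ids, rows)                      # write back to the CPU weight
    cached_idx_map.index_fill_(0, slots, -1)
    inverted_cached_idx.index_fill_(0, row_ids, -1)

    assert torch.equal(weight[4], torch.ones(dim)) and torch.equal(weight[2], torch.ones(dim))
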
@@ -280,25 +279,25 @@ class CachedParamMgr(torch.nn.Module):
             cpu_row_idxs = torch.unique(cpu_row_idxs_original)
 
             assert len(cpu_row_idxs) <= self.cuda_row_num, \
-                f"the input indices pull {len(cpu_row_idxs)} chunks, " \
-                f"which is larger than the presented {self.cuda_row_num}, " \
-                f"please increase cuda_row_num shrink batch size"
+                f"You move {len(cpu_row_idxs)} embedding rows from CPU to CUDA. " \
+                f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
+                f"Please increase cuda_row_num or decrease the training batch size."
             self.evict_backlist = cpu_row_idxs
 
-        with record_function("(zhg) get cpu chunk indices"):
+        with record_function("(zhg) get cpu row idxs"):
             comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
 
         self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
         self.num_miss_history.append(len(comm_cpu_row_idxs))
         self.num_write_back_history.append(0)
 
-        # move sure the cuda chunk will not be evicted!
+        # move sure the cuda rows will not be evicted!
         with record_function("(zhg) cache update"):
             self._prepare_rows_on_cuda(comm_cpu_row_idxs)
 
         self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
-        # new ids chunk_offset + offset_in_chunk
-        with record_function("(zhg) embed idx -> cache chunk id"):
+
+        with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)
 
         # update for LFU.
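
The hit/miss split relies on torch.isin: ids already present in cached_idx_map are hits, and the rest (invert=True) are the rows that must be moved in from CPU. A small standalone sketch with made-up ids:

    import torch

    cached_idx_map = torch.tensor([7, 3, -1, 12])    # rows currently held by the cache slots
    cpu_row_idxs = torch.unique(torch.tensor([3, 3, 5, 12, 40]))

    # Rows NOT already cached have to be fetched from the CPU weight.
    comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, cached_idx_map, invert=True)]

    num_hits = len(cpu_row_idxs) - len(comm_cpu_row_idxs)
    num_miss = len(comm_cpu_row_idxs)
    print(comm_cpu_row_idxs.tolist(), num_hits, num_miss)   # [5, 40] 2 2
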
@@ -311,17 +310,17 @@ class CachedParamMgr(torch.nn.Module):
         self._cuda_to_cpu_elapse = 0
         self._cuda_to_cpu_numel = 0
 
-    def _chunk_in_cuda(self, chunk_id: int) -> bool:
-        return self.inverted_cached_idx[chunk_id] != -1
+    def _row_in_cuda(self, row_id: int) -> bool:
+        return self.inverted_cached_idx[row_id] != -1
 
     @torch.no_grad()
     def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
         """prepare rows in cpu_row_idxs on CUDA memory
 
         Args:
-            cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
+            cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
         """
-        evict_num = cpu_row_idxs.numel() - self.cuda_available_chunk_num
+        evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num
         if evict_num > 0:
             with Timer() as timer:
                 mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
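
Eviction kicks in only when the incoming rows outnumber the free cache slots, and slots whose rows appear in evict_backlist (the rows the current batch still needs) are excluded as victims; the torch.isin mask above marks exactly those protected slots. A short sketch of that bookkeeping with toy values:

    import torch

    cuda_available_row_num = 1                    # free slots left in the cache
    cpu_row_idxs = torch.tensor([5, 40, 41])      # rows that must become resident
    evict_num = cpu_row_idxs.numel() - cuda_available_row_num   # 2 rows must be evicted

    cached_idx_map = torch.tensor([7, 3, 5, 12])  # slot -> row id
    evict_backlist = cpu_row_idxs
    # True where a slot holds a row the current batch needs, i.e. not evictable.
    mask_cpu_row_idx = torch.isin(cached_idx_map, evict_backlist)
    evictable_slots = torch.nonzero(~mask_cpu_row_idx).squeeze(1)
    print(evict_num, evictable_slots.tolist())    # 2 [0, 1, 3]
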
@@ -396,7 +395,7 @@ class CachedParamMgr(torch.nn.Module):
         """
         deprecated
 
-        evict one chunk from cuda to cpu.
+        evict one row from cuda to cpu.
         Returns:
         (int) : the slot id be evicted.
         """