@@ -6,6 +6,7 @@ from contexttimer import Timer
 from .copyer import LimitBuffIndexCopyer
 from enum import Enum
 import sys
+from contextlib import contextmanager
 
 
 class EvictionStrategy(Enum):
@@ -55,7 +56,6 @@ class CachedParamMgr(torch.nn.Module):
         self.num_hits_history = []
         self.num_miss_history = []
         self.num_write_back_history = []
-        self._reset_comm_stats()
 
         self._evict_strategy = evict_strategy
 
@@ -71,6 +71,25 @@ class CachedParamMgr(torch.nn.Module):
                                  torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
                                              dtype=torch.long).fill_(sys.maxsize),
                                  persistent=False)
+        self._elapsed_dict = {}
+        self._reset_comm_stats()
+
+    def _reset_comm_stats(self):
+        for k in self._elapsed_dict.keys():
+            self._elapsed_dict[k] = 0
+
+        self._cpu_to_cuda_numel = 0
+        self._cuda_to_cpu_numel = 0
+
+    @contextmanager
+    def timer(self, name):
+        with Timer() as t:
+            yield
+            torch.cuda.synchronize()
+
+        if name not in self._elapsed_dict.keys():
+            self._elapsed_dict[name] = 0
+        self._elapsed_dict[name] += t.elapsed
+
     def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
         """_find_evict_gpu_idxs
@@ -193,7 +212,6 @@ class CachedParamMgr(torch.nn.Module):
                 else:
                     preload_cpu_ids = torch.arange(preload_row_num)
                     preload_cuda_row_idxs = preload_cpu_ids.cuda()
-
                 if self.buffer_size > 0:
                     self.limit_buff_index_copyer.index_copy(0,
                                                             src_index=preload_cpu_ids,
@@ -239,13 +257,20 @@ class CachedParamMgr(torch.nn.Module):
     def print_comm_stats(self):
         if self._cuda_to_cpu_numel > 0:
+            elapsed = self._elapsed_dict["3_2_2_evict_out_gpu_to_cpu_copy"]
             print(
-                f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / self._cuda_to_cpu_elapse} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
+                f"CUDA->CPU BWD {self._cuda_to_cpu_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cuda_to_cpu_numel / 1e6} M elem"
             )
+            print(f'cuda_to_cpu_elapse {elapsed} sec')
         if self._cpu_to_cuda_numel > 0:
+            elapsed = self._elapsed_dict["3_4_2_evict_in_gpu_to_cpu_copy"]
             print(
-                f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / self._cpu_to_cuda_elpase} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
+                f"CPU->CUDA BWD {self._cpu_to_cuda_numel * self.elem_size_in_byte / 1e6 / elapsed} MB/s {self._cpu_to_cuda_numel / 1e6} M elem"
             )
+            print(f'cpu_to_cuda_elapse {elapsed} sec')
+
+        for k, v in self._elapsed_dict.items():
+            print(f'{k}: {v}')
 
     @torch.no_grad()
     def _id_to_cached_cuda_id(self, ids: torch.Tensor) -> torch.Tensor:
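For reference, the bandwidth arithmetic in print_comm_stats is numel * elem_size_in_byte / 1e6 / elapsed, i.e. megabytes moved divided by seconds: 2e6 float32 elements (4 bytes each) copied in 0.004 s print as 8 MB / 0.004 s = 2000 MB/s and "2.0 M elem". Both elapsed values are now read out of `_elapsed_dict` under the staged keys instead of the `_cuda_to_cpu_elapse`/`_cpu_to_cuda_elpase` accumulators this patch removes. One quirk worth knowing: the CPU->CUDA line reads the key "3_4_2_evict_in_gpu_to_cpu_copy", whose name says gpu-to-cpu but which, as the evict-in hunk further down shows, actually times the host-to-device copy.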
@@ -270,7 +295,9 @@ class CachedParamMgr(torch.nn.Module):
         Returns:
             torch.Tensor: indices on the cuda_cached_weight.
         """
-        with record_function("(zhg) get unique indices"):
-            cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
+        torch.cuda.synchronize()
+        with self.timer("1_unique_indices") as timer:
+            with record_function("(cache) get unique indices"):
+                cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)
 
         assert len(cpu_row_idxs) <= self.cuda_row_num, \
@@ -278,8 +305,11 @@ class CachedParamMgr(torch.nn.Module):
             f"It is larger than the capacity of the cache, which at most contains {self.cuda_row_num} rows, " \
             f"Please increase cuda_row_num or decrease the training batch size."
         self.evict_backlist = cpu_row_idxs
 
-        with record_function("(zhg) get cpu row idxs"):
-            # O(cache ratio)
-            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
+        torch.cuda.synchronize()
+        with self.timer("2_cpu_row_idx") as timer:
+            with record_function("(cache) get cpu row idxs"):
+                comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]
 
         self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
@@ -287,12 +317,14 @@ class CachedParamMgr(torch.nn.Module):
             self.num_write_back_history.append(0)
 
         # move sure the cuda rows will not be evicted!
-        with record_function("(zhg) cache update"):
-            self._prepare_rows_on_cuda(comm_cpu_row_idxs)
+        with self.timer("3_prepare_rows_on_cuda") as timer:
+            with record_function("(cache) prepare_rows_on_cuda"):
+                self._prepare_rows_on_cuda(comm_cpu_row_idxs)
 
         self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
 
-        with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
-            gpu_row_idxs = self._id_to_cached_cuda_id(ids)
+        with self.timer("4_cpu_to_gpu_row_idxs") as timer:
+            with record_function("(cache) embed cpu rows idx -> cache gpu row idxs"):
+                gpu_row_idxs = self._id_to_cached_cuda_id(ids)
 
         # update for LFU.
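The forward path above settles into one convention per stage: an outer `self.timer(...)` block feeds `_elapsed_dict` (and hence `print_comm_stats`), and a `record_function(...)` span nested inside labels the same region for PyTorch profiler traces, with the old "(zhg)" prefixes renamed to "(cache)". Reusing the hypothetical StageTimer from the sketch above, the pairing looks like this (the workload line is only a stand-in for the real cache update):

    import torch
    from torch.profiler import record_function

    t = StageTimer()    # hypothetical class from the earlier sketch
    with t.timer("3_prepare_rows_on_cuda"):                    # wall time -> _elapsed_dict
        with record_function("(cache) prepare_rows_on_cuda"):  # labelled span in profiler traces
            torch.ones(1024, 1024).sum()
    print(t._elapsed_dict["3_prepare_rows_on_cuda"])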
@@ -302,12 +334,6 @@ class CachedParamMgr(torch.nn.Module):
 
         return gpu_row_idxs
 
-    def _reset_comm_stats(self):
-        self._cpu_to_cuda_numel = 0
-        self._cpu_to_cuda_elpase = 0
-        self._cuda_to_cpu_elapse = 0
-        self._cuda_to_cpu_numel = 0
-
     def _row_in_cuda(self, row_id: int) -> bool:
         return self.inverted_cached_idx[row_id] != -1
 
@@ -324,14 +350,17 @@ class CachedParamMgr(torch.nn.Module):
         # move evict in rows to gpu
         if self._async_copy:
             if self.buffer_size == 0:
-                rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
-                evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
-                evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
+                idxslt_stream = torch.cuda.Stream()
+                with torch.cuda.stream(idxslt_stream):
+                    rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
+                # evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
+                # evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
             else:
                 raise NotImplemented
 
         if evict_num > 0:
-            with Timer() as timer:
+            torch.cuda.synchronize()
+            with self.timer("3_1_evict_prepare") as timer:
                 mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
                 invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                 if self._evict_strategy == EvictionStrategy.DATASET:
@@ -340,6 +369,8 @@ class CachedParamMgr(torch.nn.Module):
                     # so those idxs will be sorted to end, therefore not being chosen as victim
                     backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
                     self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
-                    evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+
+                    with self.timer("3_1_1_find_evict_gpu_idxs_elapsed") as timer:
+                        evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
 
                     # move evict out rows to cpu
@@ -353,7 +384,10 @@ class CachedParamMgr(torch.nn.Module):
                 elif self._evict_strategy == EvictionStrategy.LFU:
                     backup_freqs = self.freq_cnter[invalid_idxs].clone()
                     self.freq_cnter.index_fill_(0, invalid_idxs, sys.maxsize)
-                    evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+
+                    with self.timer("3_1_1_find_evict_gpu_idxs_elapsed") as timer:
+                        evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
 
                 if self._async_copy:
                     evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
                                                                       -1).index_select(0, evict_gpu_row_idxs)
@@ -363,6 +397,7 @@ class CachedParamMgr(torch.nn.Module):
 
             evict_info = self.cached_idx_map[evict_gpu_row_idxs]
 
+            with self.timer("3_2_evict_out_elapse") as timer:
                 if self.buffer_size > 0:
                     self.limit_buff_index_copyer.index_copy(0,
                                                             src_index=evict_gpu_row_idxs,
@@ -375,8 +410,12 @@ class CachedParamMgr(torch.nn.Module):
                     if self._async_copy:
                         pass
                     else:
-                        evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
-                                                                          -1).index_select(0, evict_gpu_row_idxs).cpu()
+                        with self.timer("3_2_1_evict_out_index_select") as timer:
+                            evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
+                                                                              -1).index_select(0, evict_gpu_row_idxs)
+                        with self.timer("3_2_2_evict_out_gpu_to_cpu_copy") as timer:
+                            evict_out_rows_cpu = evict_out_rows_cpu.cpu()
 
                     self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), evict_out_rows_cpu)
 
             self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
@@ -385,12 +424,15 @@ class CachedParamMgr(torch.nn.Module):
 
             self._cuda_available_row_num += evict_num
 
             weight_size = evict_gpu_row_idxs.numel() * self.embedding_dim
-            self._cuda_to_cpu_elapse += timer.elapsed
             self._cuda_to_cpu_numel += weight_size
             # print(f"evict embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
 
-        with Timer() as timer:
+        # slots of cuda weight to evict in
+        with self.timer("3_3_non_zero") as timer:
             slots = torch.nonzero(self.cached_idx_map == -1).squeeze(1)[:cpu_row_idxs.numel()]
 
+        # TODO wait for optimize
+        with self.timer("3_4_evict_in_elapse") as timer:
             # Here also allocate extra memory on CUDA. #cpu_row_idxs
             if self.buffer_size > 0:
                 self.limit_buff_index_copyer.index_copy(0,
@@ -399,22 +441,34 @@ class CachedParamMgr(torch.nn.Module):
                                                         src=self.weight.view(self.num_embeddings, -1),
                                                         tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
             else:
-                # TODO async copy cpu -> gpu
                 if self._async_copy:
+                    torch.cuda.current_stream().wait_stream(idxslt_stream)
+                    evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
+                    evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
                     pass
                 else:
-                    evict_in_rows_gpu = self.weight.view(self.num_embeddings,
-                                                         -1).index_select(0, cpu_row_idxs_copy).cuda()
+                    # TODO hotspot: index select copy cpu -> gpu, cpu index?
+                    with self.timer("3_4_1_evict_in_index_select") as timer:
+                        # narrow index select to a subset of self.weight
+                        # tmp = torch.narrow(self.weight.view(self.num_embeddings, -1), 0, min(cpu_row_idxs).cpu(), max(cpu_row_idxs) - min(cpu_row_idxs) + 1)
+                        # evict_in_rows_gpu = tmp.index_select(0, cpu_row_idxs_copy - min(cpu_row_idxs).cpu())
+                        evict_in_rows_gpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy)
+                    with self.timer("3_4_2_evict_in_gpu_to_cpu_copy") as timer:
+                        evict_in_rows_gpu = evict_in_rows_gpu.cuda()
 
+                with self.timer("3_4_3_evict_in_index_copy") as timer:
                     self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, evict_in_rows_gpu)
 
+        with self.timer("3_4_evict_in_elapse") as timer:
             slot_offsets = slots
             self.cached_idx_map[slots] = cpu_row_idxs
             self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
             if self._evict_strategy == EvictionStrategy.LFU:
                 self.freq_cnter.index_fill_(0, slots, 0)
             self._cuda_available_row_num -= cpu_row_idxs.numel()
-        self._cpu_to_cuda_elpase += timer.elapsed
 
         weight_size = cpu_row_idxs.numel() * self.embedding_dim
         self._cpu_to_cuda_numel += weight_size
         # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
@@ -461,7 +515,6 @@ class CachedParamMgr(torch.nn.Module):
         self._cuda_available_row_num += 1
 
         self._cuda_to_cpu_numel += self.embedding_dim
-        self._cuda_to_cpu_elapse += timer.elapsed
         # self.num_write_back_history[-1] += 1
         return max_cpu_row_idx
 
@@ -495,4 +548,3 @@ class CachedParamMgr(torch.nn.Module):
         self._cuda_available_row_num -= 1
 
         self._cpu_to_cuda_numel += self.embedding_dim
-        self._cpu_to_cuda_elpase += timer.elapsed