|
|
|
@ -15,6 +15,23 @@ class EvictionStrategy(Enum):
|
|
|
|
|
DATASET = 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None: |
|
|
|
|
if stream is None: |
|
|
|
|
return |
|
|
|
|
torch.cuda.current_stream().wait_stream(stream) |
|
|
|
|
# As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html, |
|
|
|
|
# PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is |
|
|
|
|
# freed, its memory is likely to be reused by newly constructed tenosrs. By default, |
|
|
|
|
# this allocator traces whether a tensor is still in use by only the CUDA stream where it |
|
|
|
|
# was created. When a tensor is used by additional CUDA streams, we need to call record_stream |
|
|
|
|
# to tell the allocator about all these streams. Otherwise, the allocator might free the |
|
|
|
|
# underlying memory of the tensor once it is no longer used by the creator stream. This is |
|
|
|
|
# a notable programming trick when we write programs using multi CUDA streams. |
|
|
|
|
cur_stream = torch.cuda.current_stream() |
|
|
|
|
assert isinstance(t, torch.Tensor) |
|
|
|
|
t.record_stream(cur_stream) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CachedParamMgr(torch.nn.Module): |
|
|
|
|
""" |
|
|
|
|
Manage Embedding Weights on CPU and CUDA memory using a software cache. |
|
|
|
@ -37,7 +54,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
weight: torch.Tensor, |
|
|
|
|
cuda_row_num: int = 0, |
|
|
|
|
buffer_size: int = 0, |
|
|
|
|
pin_weight: bool = False, |
|
|
|
|
pin_weight: bool = True, |
|
|
|
|
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, |
|
|
|
|
async_copy: bool = False, |
|
|
|
|
) -> None: |
|
|
|
@ -62,6 +79,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
self._async_copy = async_copy |
|
|
|
|
|
|
|
|
|
if self._async_copy: |
|
|
|
|
self._memcpy_stream = torch.cuda.Stream() |
|
|
|
|
|
|
|
|
|
print('use async copy') |
|
|
|
|
|
|
|
|
|
if self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
@ -350,11 +369,10 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
# move evict in rows to gpu |
|
|
|
|
if self._async_copy: |
|
|
|
|
if self.buffer_size == 0: |
|
|
|
|
idxslt_stream = torch.cuda.Stream() |
|
|
|
|
with torch.cuda.stream(idxslt_stream): |
|
|
|
|
rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() |
|
|
|
|
# evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device()) |
|
|
|
|
# evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True) |
|
|
|
|
evict_in_rows_gpu = self.weight.view(self.num_embeddings, |
|
|
|
|
-1).index_select(0, cpu_row_idxs_copy).pin_memory() |
|
|
|
|
with torch.cuda.stream(self._memcpy_stream): |
|
|
|
|
evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True) |
|
|
|
|
else: |
|
|
|
|
raise NotImplemented |
|
|
|
|
|
|
|
|
@ -378,7 +396,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, |
|
|
|
|
-1).index_select(0, evict_gpu_row_idxs) |
|
|
|
|
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) |
|
|
|
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) |
|
|
|
|
with torch.cuda.stream(None): |
|
|
|
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) |
|
|
|
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) |
|
|
|
|
|
|
|
|
|
elif self._evict_strategy == EvictionStrategy.LFU: |
|
|
|
@ -393,7 +412,8 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, |
|
|
|
|
-1).index_select(0, evict_gpu_row_idxs) |
|
|
|
|
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) |
|
|
|
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) |
|
|
|
|
with torch.cuda.stream(None): |
|
|
|
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) |
|
|
|
|
with self.timer("3_1_2_find_evict_index_copy") as timer: |
|
|
|
|
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) |
|
|
|
|
|
|
|
|
@ -410,7 +430,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
# allocate tmp memory on CPU and copy rows on CUDA to CPU. |
|
|
|
|
# TODO async gpu -> cpu |
|
|
|
|
if self._async_copy: |
|
|
|
|
pass |
|
|
|
|
_wait_for_data(evict_out_rows_cpu, None) |
|
|
|
|
else: |
|
|
|
|
with self.timer("3_2_1_evict_out_index_select") as timer: |
|
|
|
|
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, |
|
|
|
@ -445,10 +465,7 @@ class CachedParamMgr(torch.nn.Module):
|
|
|
|
|
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) |
|
|
|
|
else: |
|
|
|
|
if self._async_copy: |
|
|
|
|
torch.cuda.current_stream().wait_stream(idxslt_stream) |
|
|
|
|
evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device()) |
|
|
|
|
evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True) |
|
|
|
|
pass |
|
|
|
|
_wait_for_data(evict_in_rows_gpu, self._memcpy_stream) |
|
|
|
|
else: |
|
|
|
|
with self.timer("3_4_1_evict_in_index_select") as timer: |
|
|
|
|
# narrow index select to a subset of self.weight |
|
|
|
|