[embedding] polish async copy (#1657)

pull/1662/head
Jiarui Fang 2022-09-27 14:37:03 +08:00 committed by GitHub
parent 988570e4a6
commit c638bec028
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 14 deletions

View File

@ -15,6 +15,23 @@ class EvictionStrategy(Enum):
DATASET = 2 DATASET = 2
def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None:
if stream is None:
return
torch.cuda.current_stream().wait_stream(stream)
# As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
# PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is
# freed, its memory is likely to be reused by newly constructed tenosrs. By default,
# this allocator traces whether a tensor is still in use by only the CUDA stream where it
# was created. When a tensor is used by additional CUDA streams, we need to call record_stream
# to tell the allocator about all these streams. Otherwise, the allocator might free the
# underlying memory of the tensor once it is no longer used by the creator stream. This is
# a notable programming trick when we write programs using multi CUDA streams.
cur_stream = torch.cuda.current_stream()
assert isinstance(t, torch.Tensor)
t.record_stream(cur_stream)
class CachedParamMgr(torch.nn.Module): class CachedParamMgr(torch.nn.Module):
""" """
Manage Embedding Weights on CPU and CUDA memory uses a software cache. Manage Embedding Weights on CPU and CUDA memory uses a software cache.
@ -37,7 +54,7 @@ class CachedParamMgr(torch.nn.Module):
weight: torch.Tensor, weight: torch.Tensor,
cuda_row_num: int = 0, cuda_row_num: int = 0,
buffer_size: int = 0, buffer_size: int = 0,
pin_weight: bool = False, pin_weight: bool = True,
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET, evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
async_copy: bool = False, async_copy: bool = False,
) -> None: ) -> None:
@ -62,6 +79,8 @@ class CachedParamMgr(torch.nn.Module):
self._async_copy = async_copy self._async_copy = async_copy
if self._async_copy: if self._async_copy:
self._memcpy_stream = torch.cuda.Stream()
print('use async copy') print('use async copy')
if self._evict_strategy == EvictionStrategy.LFU: if self._evict_strategy == EvictionStrategy.LFU:
@ -350,11 +369,10 @@ class CachedParamMgr(torch.nn.Module):
# move evict in rows to gpu # move evict in rows to gpu
if self._async_copy: if self._async_copy:
if self.buffer_size == 0: if self.buffer_size == 0:
idxslt_stream = torch.cuda.Stream() evict_in_rows_gpu = self.weight.view(self.num_embeddings,
with torch.cuda.stream(idxslt_stream): -1).index_select(0, cpu_row_idxs_copy).pin_memory()
rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory() with torch.cuda.stream(self._memcpy_stream):
# evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device()) evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True)
# evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
else: else:
raise NotImplemented raise NotImplemented
@ -378,7 +396,8 @@ class CachedParamMgr(torch.nn.Module):
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
-1).index_select(0, evict_gpu_row_idxs) -1).index_select(0, evict_gpu_row_idxs)
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) with torch.cuda.stream(None):
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
elif self._evict_strategy == EvictionStrategy.LFU: elif self._evict_strategy == EvictionStrategy.LFU:
@ -393,7 +412,8 @@ class CachedParamMgr(torch.nn.Module):
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num, evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
-1).index_select(0, evict_gpu_row_idxs) -1).index_select(0, evict_gpu_row_idxs)
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True) evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True) with torch.cuda.stream(None):
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
with self.timer("3_1_2_find_evict_index_copy") as timer: with self.timer("3_1_2_find_evict_index_copy") as timer:
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs) self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
@ -410,7 +430,7 @@ class CachedParamMgr(torch.nn.Module):
# allocate tmp memory on CPU and copy rows on CUDA to CPU. # allocate tmp memory on CPU and copy rows on CUDA to CPU.
# TODO async gpu -> cpu # TODO async gpu -> cpu
if self._async_copy: if self._async_copy:
pass _wait_for_data(evict_out_rows_cpu, None)
else: else:
with self.timer("3_2_1_evict_out_index_select") as timer: with self.timer("3_2_1_evict_out_index_select") as timer:
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num, evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
@ -445,10 +465,7 @@ class CachedParamMgr(torch.nn.Module):
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1)) tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
else: else:
if self._async_copy: if self._async_copy:
torch.cuda.current_stream().wait_stream(idxslt_stream) _wait_for_data(evict_in_rows_gpu, self._memcpy_stream)
evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
pass
else: else:
with self.timer("3_4_1_evict_in_index_select") as timer: with self.timer("3_4_1_evict_in_index_select") as timer:
# narrow index select to a subset of self.weight # narrow index select to a subset of self.weight

View File

@ -66,6 +66,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight) self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
self.cache_op = True self.cache_op = True
def set_cache_mgr_async_copy(self, flag):
self.cache_weight_mgr._async_copy = flag
def _weight_alloc(self, dtype, device): def _weight_alloc(self, dtype, device):
weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device) weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
with torch.no_grad(): with torch.no_grad():

View File

@ -114,7 +114,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
cuda_row_num: int = 100_000, cuda_row_num: int = 100_000,
ids_freq_mapping: Optional[List[int]] = None, ids_freq_mapping: Optional[List[int]] = None,
warmup_ratio: float = 0.7, warmup_ratio: float = 0.7,
buffer_size: int = 50_000, buffer_size: int = 0,
) -> 'ParallelFreqAwareEmbeddingBag': ) -> 'ParallelFreqAwareEmbeddingBag':
rows, cols = embedding.shape rows, cols = embedding.shape
embedding_bag = cls(rows, embedding_bag = cls(rows,