mirror of https://github.com/hpcaitech/ColossalAI
[embedding] polish async copy (#1657)
parent
988570e4a6
commit
c638bec028
|
@ -15,6 +15,23 @@ class EvictionStrategy(Enum):
|
|||
DATASET = 2
|
||||
|
||||
|
||||
def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None:
|
||||
if stream is None:
|
||||
return
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
# As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
|
||||
# PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is
|
||||
# freed, its memory is likely to be reused by newly constructed tenosrs. By default,
|
||||
# this allocator traces whether a tensor is still in use by only the CUDA stream where it
|
||||
# was created. When a tensor is used by additional CUDA streams, we need to call record_stream
|
||||
# to tell the allocator about all these streams. Otherwise, the allocator might free the
|
||||
# underlying memory of the tensor once it is no longer used by the creator stream. This is
|
||||
# a notable programming trick when we write programs using multi CUDA streams.
|
||||
cur_stream = torch.cuda.current_stream()
|
||||
assert isinstance(t, torch.Tensor)
|
||||
t.record_stream(cur_stream)
|
||||
|
||||
|
||||
class CachedParamMgr(torch.nn.Module):
|
||||
"""
|
||||
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
|
||||
|
@ -37,7 +54,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
weight: torch.Tensor,
|
||||
cuda_row_num: int = 0,
|
||||
buffer_size: int = 0,
|
||||
pin_weight: bool = False,
|
||||
pin_weight: bool = True,
|
||||
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
||||
async_copy: bool = False,
|
||||
) -> None:
|
||||
|
@ -62,6 +79,8 @@ class CachedParamMgr(torch.nn.Module):
|
|||
self._async_copy = async_copy
|
||||
|
||||
if self._async_copy:
|
||||
self._memcpy_stream = torch.cuda.Stream()
|
||||
|
||||
print('use async copy')
|
||||
|
||||
if self._evict_strategy == EvictionStrategy.LFU:
|
||||
|
@ -350,11 +369,10 @@ class CachedParamMgr(torch.nn.Module):
|
|||
# move evict in rows to gpu
|
||||
if self._async_copy:
|
||||
if self.buffer_size == 0:
|
||||
idxslt_stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(idxslt_stream):
|
||||
rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
|
||||
# evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
|
||||
# evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
|
||||
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
|
||||
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
|
||||
with torch.cuda.stream(self._memcpy_stream):
|
||||
evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True)
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
|
@ -378,7 +396,8 @@ class CachedParamMgr(torch.nn.Module):
|
|||
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||
-1).index_select(0, evict_gpu_row_idxs)
|
||||
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||
with torch.cuda.stream(None):
|
||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||
|
||||
elif self._evict_strategy == EvictionStrategy.LFU:
|
||||
|
@ -393,7 +412,8 @@ class CachedParamMgr(torch.nn.Module):
|
|||
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||
-1).index_select(0, evict_gpu_row_idxs)
|
||||
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||
with torch.cuda.stream(None):
|
||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||
with self.timer("3_1_2_find_evict_index_copy") as timer:
|
||||
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
|
||||
|
||||
|
@ -410,7 +430,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
|
||||
# TODO async gpu -> cpu
|
||||
if self._async_copy:
|
||||
pass
|
||||
_wait_for_data(evict_out_rows_cpu, None)
|
||||
else:
|
||||
with self.timer("3_2_1_evict_out_index_select") as timer:
|
||||
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||
|
@ -445,10 +465,7 @@ class CachedParamMgr(torch.nn.Module):
|
|||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||
else:
|
||||
if self._async_copy:
|
||||
torch.cuda.current_stream().wait_stream(idxslt_stream)
|
||||
evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
|
||||
evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
|
||||
pass
|
||||
_wait_for_data(evict_in_rows_gpu, self._memcpy_stream)
|
||||
else:
|
||||
with self.timer("3_4_1_evict_in_index_select") as timer:
|
||||
# narrow index select to a subset of self.weight
|
||||
|
|
|
@ -66,6 +66,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
|||
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
|
||||
self.cache_op = True
|
||||
|
||||
def set_cache_mgr_async_copy(self, flag):
|
||||
self.cache_weight_mgr._async_copy = flag
|
||||
|
||||
def _weight_alloc(self, dtype, device):
|
||||
weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
|
||||
with torch.no_grad():
|
||||
|
|
|
@ -114,7 +114,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
|||
cuda_row_num: int = 100_000,
|
||||
ids_freq_mapping: Optional[List[int]] = None,
|
||||
warmup_ratio: float = 0.7,
|
||||
buffer_size: int = 50_000,
|
||||
buffer_size: int = 0,
|
||||
) -> 'ParallelFreqAwareEmbeddingBag':
|
||||
rows, cols = embedding.shape
|
||||
embedding_bag = cls(rows,
|
||||
|
|
Loading…
Reference in New Issue