mirror of https://github.com/hpcaitech/ColossalAI
[embedding] polish async copy (#1657)
parent
988570e4a6
commit
c638bec028
|
@ -15,6 +15,23 @@ class EvictionStrategy(Enum):
|
||||||
DATASET = 2
|
DATASET = 2
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None:
|
||||||
|
if stream is None:
|
||||||
|
return
|
||||||
|
torch.cuda.current_stream().wait_stream(stream)
|
||||||
|
# As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
|
||||||
|
# PyTorch uses the "caching allocator" for memroy allocation for tensors. When a tensor is
|
||||||
|
# freed, its memory is likely to be reused by newly constructed tenosrs. By default,
|
||||||
|
# this allocator traces whether a tensor is still in use by only the CUDA stream where it
|
||||||
|
# was created. When a tensor is used by additional CUDA streams, we need to call record_stream
|
||||||
|
# to tell the allocator about all these streams. Otherwise, the allocator might free the
|
||||||
|
# underlying memory of the tensor once it is no longer used by the creator stream. This is
|
||||||
|
# a notable programming trick when we write programs using multi CUDA streams.
|
||||||
|
cur_stream = torch.cuda.current_stream()
|
||||||
|
assert isinstance(t, torch.Tensor)
|
||||||
|
t.record_stream(cur_stream)
|
||||||
|
|
||||||
|
|
||||||
class CachedParamMgr(torch.nn.Module):
|
class CachedParamMgr(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
|
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
|
||||||
|
@ -37,7 +54,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
cuda_row_num: int = 0,
|
cuda_row_num: int = 0,
|
||||||
buffer_size: int = 0,
|
buffer_size: int = 0,
|
||||||
pin_weight: bool = False,
|
pin_weight: bool = True,
|
||||||
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
||||||
async_copy: bool = False,
|
async_copy: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -62,6 +79,8 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
self._async_copy = async_copy
|
self._async_copy = async_copy
|
||||||
|
|
||||||
if self._async_copy:
|
if self._async_copy:
|
||||||
|
self._memcpy_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
print('use async copy')
|
print('use async copy')
|
||||||
|
|
||||||
if self._evict_strategy == EvictionStrategy.LFU:
|
if self._evict_strategy == EvictionStrategy.LFU:
|
||||||
|
@ -350,11 +369,10 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
# move evict in rows to gpu
|
# move evict in rows to gpu
|
||||||
if self._async_copy:
|
if self._async_copy:
|
||||||
if self.buffer_size == 0:
|
if self.buffer_size == 0:
|
||||||
idxslt_stream = torch.cuda.Stream()
|
evict_in_rows_gpu = self.weight.view(self.num_embeddings,
|
||||||
with torch.cuda.stream(idxslt_stream):
|
-1).index_select(0, cpu_row_idxs_copy).pin_memory()
|
||||||
rows_cpu = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs_copy).pin_memory()
|
with torch.cuda.stream(self._memcpy_stream):
|
||||||
# evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
|
evict_in_rows_gpu = evict_in_rows_gpu.to(torch.cuda.current_device(), non_blocking=True)
|
||||||
# evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplemented
|
raise NotImplemented
|
||||||
|
|
||||||
|
@ -378,7 +396,8 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||||
-1).index_select(0, evict_gpu_row_idxs)
|
-1).index_select(0, evict_gpu_row_idxs)
|
||||||
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
||||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
with torch.cuda.stream(None):
|
||||||
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||||
|
|
||||||
elif self._evict_strategy == EvictionStrategy.LFU:
|
elif self._evict_strategy == EvictionStrategy.LFU:
|
||||||
|
@ -393,7 +412,8 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
evict_out_rows_gpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||||
-1).index_select(0, evict_gpu_row_idxs)
|
-1).index_select(0, evict_gpu_row_idxs)
|
||||||
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
evict_out_rows_cpu = torch.empty_like(evict_out_rows_gpu, device='cpu', pin_memory=True)
|
||||||
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
with torch.cuda.stream(None):
|
||||||
|
evict_out_rows_cpu.copy_(evict_out_rows_gpu, non_blocking=True)
|
||||||
with self.timer("3_1_2_find_evict_index_copy") as timer:
|
with self.timer("3_1_2_find_evict_index_copy") as timer:
|
||||||
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
|
self.freq_cnter.index_copy_(0, invalid_idxs, backup_freqs)
|
||||||
|
|
||||||
|
@ -410,7 +430,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
|
# allocate tmp memory on CPU and copy rows on CUDA to CPU.
|
||||||
# TODO async gpu -> cpu
|
# TODO async gpu -> cpu
|
||||||
if self._async_copy:
|
if self._async_copy:
|
||||||
pass
|
_wait_for_data(evict_out_rows_cpu, None)
|
||||||
else:
|
else:
|
||||||
with self.timer("3_2_1_evict_out_index_select") as timer:
|
with self.timer("3_2_1_evict_out_index_select") as timer:
|
||||||
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
evict_out_rows_cpu = self.cuda_cached_weight.view(self.cuda_row_num,
|
||||||
|
@ -445,10 +465,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
|
||||||
else:
|
else:
|
||||||
if self._async_copy:
|
if self._async_copy:
|
||||||
torch.cuda.current_stream().wait_stream(idxslt_stream)
|
_wait_for_data(evict_in_rows_gpu, self._memcpy_stream)
|
||||||
evict_in_rows_gpu = torch.empty_like(rows_cpu, device=torch.cuda.current_device())
|
|
||||||
evict_in_rows_gpu.copy_(rows_cpu, non_blocking=True)
|
|
||||||
pass
|
|
||||||
else:
|
else:
|
||||||
with self.timer("3_4_1_evict_in_index_select") as timer:
|
with self.timer("3_4_1_evict_in_index_select") as timer:
|
||||||
# narrow index select to a subset of self.weight
|
# narrow index select to a subset of self.weight
|
||||||
|
|
|
@ -66,6 +66,9 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
|
||||||
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
|
self._preprocess(_weight, cuda_row_num, ids_freq_mapping, warmup_ratio, buffer_size, pin_weight)
|
||||||
self.cache_op = True
|
self.cache_op = True
|
||||||
|
|
||||||
|
def set_cache_mgr_async_copy(self, flag):
|
||||||
|
self.cache_weight_mgr._async_copy = flag
|
||||||
|
|
||||||
def _weight_alloc(self, dtype, device):
|
def _weight_alloc(self, dtype, device):
|
||||||
weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
|
weight = torch.empty(self.num_embeddings, self.embedding_dim, dtype=dtype, device=device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
|
@ -114,7 +114,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
||||||
cuda_row_num: int = 100_000,
|
cuda_row_num: int = 100_000,
|
||||||
ids_freq_mapping: Optional[List[int]] = None,
|
ids_freq_mapping: Optional[List[int]] = None,
|
||||||
warmup_ratio: float = 0.7,
|
warmup_ratio: float = 0.7,
|
||||||
buffer_size: int = 50_000,
|
buffer_size: int = 0,
|
||||||
) -> 'ParallelFreqAwareEmbeddingBag':
|
) -> 'ParallelFreqAwareEmbeddingBag':
|
||||||
rows, cols = embedding.shape
|
rows, cols = embedding.shape
|
||||||
embedding_bag = cls(rows,
|
embedding_bag = cls(rows,
|
||||||
|
|
Loading…
Reference in New Issue