[embedding] rollback for better FAW performance (#1625)

Jiarui Fang 2022-09-22 11:16:25 +08:00 committed by GitHub
parent d925122020
commit 38c68b5b9a
2 changed files with 50 additions and 88 deletions

View File

@@ -20,15 +20,15 @@ class CachedParamMgr(torch.nn.Module):
     CPU maintains the entire original weight.
     CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`.
     During training, GPU needs to transmit embedding rows between CPU and GPU.

     Args:
         weight (torch.Tensor): the weight of the Embedding layer.
         cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0.
         buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000.
         pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase largely. Defaults to False.
-        evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options. `EvictionStrategy.LFU` uses the least frequently used cache. `EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
-        Default as EvictionStrategy.DATASET.
-        use_cpu_caching (bool, optional): use cpu to execute cache indexing. It is slower than use gpu.
+        evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options.
+        `EvictionStrategy.LFU`: use the least frequently used cache.
+        `EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
+        Defaults to EvictionStrategy.DATASET.
     """

     def __init__(
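For orientation, a minimal construction sketch of the restored (post-rollback) API; the import path, table size, and argument values here are assumptions for illustration, not part of this diff:

import torch
from colossalai.nn.parallel.layers.cache_embedding import CachedParamMgr, EvictionStrategy

# hypothetical 10_000-row embedding table held in CPU memory
weight = torch.randn(10_000, 128)
mgr = CachedParamMgr(weight,
                     cuda_row_num=1024,                    # rows resident in the CUDA cache
                     pin_weight=True,                      # page-locked CPU weight speeds up H2D copies
                     evict_strategy=EvictionStrategy.LFU)  # the use_cpu_caching flag is gone after this commit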
@@ -38,7 +38,6 @@ class CachedParamMgr(torch.nn.Module):
         buffer_size: int = 0,
         pin_weight: bool = False,
         evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
-        use_cpu_caching=False,
     ) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
@@ -48,13 +47,6 @@ class CachedParamMgr(torch.nn.Module):
         self.pin_weight = pin_weight
         self.elem_size_in_byte = weight.element_size()

-        self._cpu_caching = use_cpu_caching
-
-        if self._cpu_caching:
-            self._cache_dev = torch.device('cpu')
-        else:
-            self._cache_dev = torch.cuda.current_device()
-
         # weight configure
         self._init_weight(weight)
@@ -69,24 +61,16 @@ class CachedParamMgr(torch.nn.Module):
         if self._evict_strategy == EvictionStrategy.LFU:
             # cache_row_idx -> frequency, freq of the cache rows.
             # classic lfu cache. evict the minimal freq value row in cuda cache.
-            if self._cpu_caching:
-                self.freq_cnter = torch.empty(self.cuda_row_num, device=self._cache_dev,
-                                              dtype=torch.long).fill_(sys.maxsize)
-            else:
-                self.register_buffer("freq_cnter",
-                                     torch.empty(self.cuda_row_num, device=self._cache_dev,
-                                                 dtype=torch.long).fill_(sys.maxsize),
-                                     persistent=False)
+            self.register_buffer("freq_cnter",
+                                 torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
+                                             dtype=torch.long).fill_(sys.maxsize),
+                                 persistent=False)

     def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
         """_find_evict_gpu_idxs
         Find the gpu idxs to be evicted, according to their freq.
-
         Args:
             evict_num (int): how many rows has to be evicted
-
         Returns:
             torch.Tensor: a list tensor (1D), contains the gpu_row_idxs.
         """
@@ -117,32 +101,26 @@ class CachedParamMgr(torch.nn.Module):
         self.weight = weight.pin_memory() if self.pin_weight else weight

         # map original id to new id with respect to frequency
         # id -> cpu_row_idx
-        if self._cpu_caching:
-            self.idx_map = torch.arange(self.num_embeddings, dtype=torch.long, device=self._cache_dev)
-            self.cached_idx_map = torch.empty(self.cuda_row_num, device=self._cache_dev, dtype=torch.long).fill_(-1)
-            self.inverted_cached_idx = torch.zeros(self.num_embeddings, device=self._cache_dev,
-                                                   dtype=torch.long).fill_(-1)
-        else:
-            self.register_buffer(
-                "idx_map",
-                torch.arange(self.num_embeddings, dtype=torch.long, device=self._cache_dev),
-                persistent=False,
-            )
-            # cached_idx_map: gpu_row_idx -> cpu_row_idx
-            self.register_buffer("cached_idx_map",
-                                 torch.empty(self.cuda_row_num, device=self._cache_dev, dtype=torch.long).fill_(-1),
-                                 persistent=False)
-            # cpu_row_id -> gpu_row_idx.
-            # gpu_row_idx as -1 means cpu_row_id not in CUDA.
-            self.register_buffer("inverted_cached_idx",
-                                 torch.zeros(self.num_embeddings, device=self._cache_dev,
-                                             dtype=torch.long).fill_(-1),
-                                 persistent=False)
-
-        self.evict_backlist = torch.tensor([], device=self._cache_dev)
+        self.register_buffer(
+            "idx_map",
+            torch.arange(self.num_embeddings, dtype=torch.long, device=torch.cuda.current_device()),
+            persistent=False,
+        )
+
+        # cached_idx_map: gpu_row_idx -> cpu_row_idx
+        self.register_buffer("cached_idx_map",
+                             torch.empty(self.cuda_row_num, device=torch.cuda.current_device(),
+                                         dtype=torch.long).fill_(-1),
+                             persistent=False)
+
+        # cpu_row_id -> gpu_row_idx.
+        # gpu_row_idx as -1 means cpu_row_id not in CUDA.
+        self.register_buffer("inverted_cached_idx",
+                             torch.zeros(self.num_embeddings, device=torch.cuda.current_device(),
+                                         dtype=torch.long).fill_(-1),
+                             persistent=False)
+
+        self.evict_backlist = torch.tensor([], device=torch.cuda.current_device())

         # index copy buffer size should less than 10% of cuda weight.
         if self.buffer_size > 0:
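The three buffers restored above form a two-level translation: dataset id -> cpu_row_idx (idx_map), then cpu_row_idx <-> gpu slot (inverted_cached_idx / cached_idx_map). A CPU-only toy example of the bookkeeping, with made-up sizes and admissions:

import torch

num_embeddings, cuda_row_num = 6, 3
idx_map = torch.arange(num_embeddings)                                      # identity until reorder()
cached_idx_map = torch.full((cuda_row_num,), -1, dtype=torch.long)          # gpu slot -> cpu row, -1 = empty
inverted_cached_idx = torch.full((num_embeddings,), -1, dtype=torch.long)   # cpu row -> gpu slot

# admit cpu rows 4 and 2 into gpu slots 0 and 1
for slot, cpu_row in [(0, 4), (1, 2)]:
    cached_idx_map[slot] = cpu_row
    inverted_cached_idx[cpu_row] = slot

ids = torch.tensor([4, 4, 2])               # a batch of dataset ids
print(inverted_cached_idx[idx_map[ids]])    # tensor([0, 0, 1]) -- rows in the CUDA cache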
@@ -157,10 +135,8 @@ class CachedParamMgr(torch.nn.Module):
     def cpu_weight_data(self, row_idx: int) -> torch.Tensor:
         """
         access a row of CPU weight.
-
         Args:
             row_idx (int): the idx of rows
-
         Returns:
             torch.Tensor: a piece of memory in CPU weight corresponding to row id's payload. The tensor is 1-D.
         """
@@ -178,14 +154,12 @@ class CachedParamMgr(torch.nn.Module):
         """reorder
         reorder the weight according to ids' frequency in dataset before training.
         Execute only once before training, also known as warmup phase.
-
         Note:
             If you would like to use the DATASET as the eviction strategy, you must call this function.
-
         Note:
             If you are use the LFU as the eviction strategy, you can skip this function. If you still use this function. It will initialize
             The frequency in LFU cache using the dataset statistics.

         Args:
             ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
             warmup_ratio (float): the amount of chunks preloaded in cuda cache
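A minimal warmup sketch, reusing the mgr from the constructor sketch above; the frequency table is fabricated (in practice it comes from a counting pass over the training set):

import torch

# ids_freq_mapping[i] = occurrence count of dataset id i; reorder() packs the
# hottest rows to the front of the CPU weight and preloads them into CUDA
ids_freq_mapping = torch.randint(0, 100, (10_000,))
mgr.reorder(ids_freq_mapping, warmup_ratio=0.7)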
@@ -209,24 +183,24 @@ class CachedParamMgr(torch.nn.Module):
             # extract rows from cpu weight
             if self._evict_strategy == EvictionStrategy.LFU and ids_freq_mapping is not None:
                 freq_value, preload_cpu_ids = torch.topk(ids_freq_mapping, preload_row_num, dim=0, largest=True)
-                preload_cuda_row_idxs = torch.arange(preload_row_num).to(self._cache_dev)
+                preload_cuda_row_idxs = torch.arange(preload_row_num).cuda()
             else:
-                preload_cpu_ids = torch.arange(preload_row_num, device=self.weight.device)
-                preload_cuda_row_idxs = preload_cpu_ids.to(self._cache_dev)
+                preload_cpu_ids = torch.arange(preload_row_num)
+                preload_cuda_row_idxs = preload_cpu_ids.cuda()

             if self.buffer_size > 0:
                 self.limit_buff_index_copyer.index_copy(0,
                                                         src_index=preload_cpu_ids,
-                                                        tgt_index=preload_cuda_row_idxs.cuda(),
+                                                        tgt_index=preload_cuda_row_idxs,
                                                         src=self.weight.view(self.num_embeddings, -1),
                                                         tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
             else:
                 preload_rows = self.weight.view(self.num_embeddings, -1).index_select(0, preload_cpu_ids).cuda()
-                self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs.cuda(),
+                self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, preload_cuda_row_idxs,
                                                                                 preload_rows)

             # update auxiliary info
-            self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.to(self._cache_dev)
+            self.cached_idx_map[preload_cuda_row_idxs] = preload_cpu_ids.cuda()
             self.inverted_cached_idx[preload_cpu_ids] = preload_cuda_row_idxs
             self._cuda_available_row_num -= preload_row_num
@@ -235,7 +209,7 @@ class CachedParamMgr(torch.nn.Module):
                 if ids_freq_mapping is None:
                     self.freq_cnter.index_fill_(0, preload_cuda_row_idxs, 0)
                 else:
-                    self.freq_cnter[preload_cuda_row_idxs] = freq_value.to(self._cache_dev)
+                    self.freq_cnter[preload_cuda_row_idxs] = freq_value.cuda()

         print(f'Cache warmup finished cost {timer.elapsed} sec.')
@@ -245,7 +219,7 @@ class CachedParamMgr(torch.nn.Module):
         """
         slots = torch.nonzero(self.cached_idx_map > -1).squeeze(1)
         row_ids = self.cached_idx_map[slots]
-        rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots.cuda()).cpu()
+        rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, slots).cpu()
         self.weight.view(self.num_embeddings, -1).index_copy_(0, row_ids.cpu(), rows)
         self.cached_idx_map.index_fill_(0, slots, -1)
         self.inverted_cached_idx.index_fill_(0, row_ids, -1)
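The flush path above is plain index arithmetic; the same write-back can be demonstrated with CPU tensors alone (shapes and occupancy are made up):

import torch

cuda_row_num, num_embeddings, dim = 4, 8, 2
cpu_weight = torch.zeros(num_embeddings, dim)
cache = torch.ones(cuda_row_num, dim)                    # stand-in for cuda_cached_weight
cached_idx_map = torch.tensor([5, -1, 2, -1])            # slots 0 and 2 hold cpu rows 5 and 2

slots = torch.nonzero(cached_idx_map > -1).squeeze(1)    # tensor([0, 2])
row_ids = cached_idx_map[slots]                          # tensor([5, 2])
cpu_weight.index_copy_(0, row_ids, cache.index_select(0, slots))
cached_idx_map.index_fill_(0, slots, -1)                 # the cache is empty again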
@@ -272,10 +246,8 @@ class CachedParamMgr(torch.nn.Module):
         """
         convert ids to indices in self.cuda_cached_weight.
         Implemented with parallel operations on GPU.
-
         Args:
             ids (torch.Tensor): ids from the dataset
-
         Returns:
             torch.Tensor: contains indices in self.cuda_cached_weight
         """
@@ -287,14 +259,12 @@ class CachedParamMgr(torch.nn.Module):
     def prepare_ids(self, ids: torch.Tensor) -> torch.Tensor:
         """
         move the cpu embedding rows w.r.t. ids into CUDA memory
-
         Args:
             ids (torch.Tensor): the ids to be computed
-
         Returns:
             torch.Tensor: indices on the cuda_cached_weight.
         """
-        with record_function(f"(pre-id) get unique indices. cache ratio {self.cuda_row_num / self.num_embeddings}"):
-            ids = ids.to(self._cache_dev)
+        with record_function("(zhg) get unique indices"):
             cpu_row_idxs, repeat_times = torch.unique(self.idx_map.index_select(0, ids), return_counts=True)

             assert len(cpu_row_idxs) <= self.cuda_row_num, \
@@ -303,29 +273,26 @@ class CachedParamMgr(torch.nn.Module):
                 f"Please increase cuda_row_num or decrease the training batch size."
             self.evict_backlist = cpu_row_idxs

-        with record_function("(pre-id) get cpu row idxs"):
-            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs,
-                                                        self.cached_idx_map,
-                                                        assume_unique=True,
-                                                        invert=True)]
+        with record_function("(zhg) get cpu row idxs"):
+            comm_cpu_row_idxs = cpu_row_idxs[torch.isin(cpu_row_idxs, self.cached_idx_map, invert=True)]

         self.num_hits_history.append(len(cpu_row_idxs) - len(comm_cpu_row_idxs))
         self.num_miss_history.append(len(comm_cpu_row_idxs))
         self.num_write_back_history.append(0)

         # move sure the cuda rows will not be evicted!
-        with record_function("(pre-id) cache update"):
+        with record_function("(zhg) cache update"):
             self._prepare_rows_on_cuda(comm_cpu_row_idxs)
-            self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)

-        with record_function("(pre-id) embed cpu rows idx -> cache gpu row idxs"):
+        self.evict_backlist = torch.tensor([], device=cpu_row_idxs.device, dtype=cpu_row_idxs.dtype)
+
+        with record_function("(zhg) embed cpu rows idx -> cache gpu row idxs"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)

         # update for LFU.
         if self._evict_strategy == EvictionStrategy.LFU:
-            with record_function("(pre-id) lfu cnter updates"):
-                unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
-                self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)
+            unique_gpu_row_idxs = self.inverted_cached_idx[cpu_row_idxs]
+            self.freq_cnter.scatter_add_(0, unique_gpu_row_idxs, repeat_times)

         return gpu_row_idxs
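End to end, prepare_ids is what a cached embedding forward runs first; a hedged usage sketch (requires a CUDA device, and reuses the mgr from the constructor sketch above):

import torch
import torch.nn.functional as F

ids = torch.randint(0, 10_000, (256,), device=torch.cuda.current_device())
gpu_row_idxs = mgr.prepare_ids(ids)      # misses are fetched from the CPU weight first
out = F.embedding(gpu_row_idxs, mgr.cuda_cached_weight.view(mgr.cuda_row_num, -1))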
@@ -341,14 +308,13 @@ class CachedParamMgr(torch.nn.Module):
     @torch.no_grad()
     def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
         """prepare rows in cpu_row_idxs on CUDA memory
-
         Args:
             cpu_row_idxs (torch.Tensor): the rows to be placed on CUDA
         """
         evict_num = cpu_row_idxs.numel() - self.cuda_available_row_num

         if evict_num > 0:
             with Timer() as timer:
-                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist, assume_unique=True)
+                mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
                 invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                 if self._evict_strategy == EvictionStrategy.DATASET:
                     # mask method.
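torch.isin against evict_backlist is what keeps rows needed by the in-flight batch out of the eviction candidates; in isolation (indices are made up):

import torch

cached_idx_map = torch.tensor([10, 3, 7, 42])      # cpu rows currently cached, one per gpu slot
evict_backlist = torch.tensor([7, 10])             # rows the current batch is about to use
mask_cpu_row_idx = torch.isin(cached_idx_map, evict_backlist)
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
print(invalid_idxs)  # tensor([0, 2]) -- these slots are protected from eviction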
@@ -375,8 +341,7 @@ class CachedParamMgr(torch.nn.Module):
                                                        tgt=self.weight.view(self.num_embeddings, -1))
                 else:
                     # allocate tmp memory on CPU and copy rows on CUDA to CPU.
-                    rows = self.cuda_cached_weight.view(self.cuda_row_num,
-                                                        -1).index_select(0, evict_gpu_row_idxs.cuda()).cpu()
+                    rows = self.cuda_cached_weight.view(self.cuda_row_num, -1).index_select(0, evict_gpu_row_idxs).cpu()
                     self.weight.view(self.num_embeddings, -1).index_copy_(0, evict_info.cpu(), rows)

                 self.cached_idx_map.index_fill_(0, evict_gpu_row_idxs, -1)
@@ -395,12 +360,12 @@ class CachedParamMgr(torch.nn.Module):
         if self.buffer_size > 0:
             self.limit_buff_index_copyer.index_copy(0,
                                                     src_index=cpu_row_idxs.cpu(),
-                                                    tgt_index=slots.cuda(),
+                                                    tgt_index=slots,
                                                     src=self.weight.view(self.num_embeddings, -1),
                                                     tgt=self.cuda_cached_weight.view(self.cuda_row_num, -1))
         else:
             rows = self.weight.view(self.num_embeddings, -1).index_select(0, cpu_row_idxs.cpu()).cuda()
-            self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots.cuda(), rows)
+            self.cuda_cached_weight.view(self.cuda_row_num, -1).index_copy_(0, slots, rows)

         slot_offsets = slots
         self.cached_idx_map[slots] = cpu_row_idxs
         self.inverted_cached_idx.index_copy_(0, cpu_row_idxs, slot_offsets)
@@ -421,7 +386,6 @@ class CachedParamMgr(torch.nn.Module):
     def _evict(self) -> int:
         """
         deprecated
-
         evict one row from cuda to cpu.
         Returns:
             (int) : the slot id be evicted.
@@ -463,9 +427,7 @@ class CachedParamMgr(torch.nn.Module):
     def _admit(self, row_id: int):
         """
         deprecated
-
         move in row_id to CUDA
-
         Args:
             row_id (int): the id of row to be moved in
         """
@@ -491,4 +453,4 @@ class CachedParamMgr(torch.nn.Module):
             self._cuda_available_row_num -= 1
             self._cpu_to_cuda_numel += self.embedding_dim
             self._cpu_to_cuda_elpase += timer.elapsed

View File

@@ -90,7 +90,7 @@ def test_reorder_with_freq():
     offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev)
     weight = torch.rand(num_embed, 2)
-    mgr = CachedParamMgr(weight, num_chunk, use_cpu_caching=dev.type == 'cpu')
+    mgr = CachedParamMgr(weight, num_chunk)

    mgr.reorder(idx_map)