[FAW] add more docs and fix a warning (#1500)

pull/1501/head^2
Jiarui Fang 2022-08-26 14:10:21 +08:00 committed by GitHub
parent 5a6fd71f90
commit d5085bb317
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 45 additions and 32 deletions

View File

@ -7,32 +7,46 @@ from .copyer import LimitBuffIndexCopyer
from enum import Enum
import sys
class EvictionStrategy(Enum):
LFU = 1
# dataset aware eviction strategy
DATASET = 2
class CachedParamMgr(torch.nn.Module):
"""
Manage Embedding Weights in Cache on CPU and CUDA memory.
CPU maintains entire original weight.
CUDA maintains a fraction of weights used in the upcomming computation.
During training, GPU needs to transmit rows between CPU and GPU.
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
CPU maintains the entire original weight.
CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`.
During training, GPU needs to transmit embedding rows between CPU and GPU.
Args:
weight (torch.Tensor): the weight of the Embedding layer.
cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0.
buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000.
pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase largely. Defaults to False.
evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options.
`EvictionStrategy.LFU`: use the least frequently used cache.
`EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
Defaults to EvictionStrategy.DATASET.
"""
def __init__(self,
weight: torch.Tensor,
cuda_row_num: int = 0,
buffer_size: int = 50_000,
pin_weight=False,
evict_strategy=EvictionStrategy.DATASET,) -> None:
def __init__(
self,
weight: torch.Tensor,
cuda_row_num: int = 0,
buffer_size: int = 50_000,
pin_weight: bool = False,
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
) -> None:
super(CachedParamMgr, self).__init__()
self.buffer_size = buffer_size
self.num_embeddings, self.embedding_dim = weight.shape
self.cuda_row_num = cuda_row_num
self._cuda_available_row_num = self.cuda_row_num
self.pin_weight = pin_weight
self.elem_size_in_byte = weight.element_size()
# weight configure
@ -50,17 +64,14 @@ class CachedParamMgr(torch.nn.Module):
if self._evict_strategy == EvictionStrategy.LFU:
# cpu_row_idx -> frequency, freq of the cpu rows.
# evict the minimal freq value row in cuda cache.
'''
during cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we re-map that element to -1 temporary.
also, disabled cached_idx_map element maps to -1 by default.
freq_cnter[-1], the last element, should ALWAYS be MAX VALUE so those masked or disabled idxs will be argsorted to end,
not being chosen to evict.
ZH: freq_cnter的最后一位设为了最大值, 不该被选为换出的cache idx都是-1, 指向这个最大值, 所以排序时在队尾, 不会被选中换出
The last element of `freq_cnter` is set to the maximum value of int.
The rows store nothing (not used) in the `self.cuda_weight` whose value is -1 in `self.cached_idx_map`.
In this way, the not used rows are placed at the end of the sorted.
'''
self.register_buffer("freq_cnter",
torch.empty(self.num_embeddings + 1, device=torch.cuda.current_device(),
torch.empty(self.num_embeddings + 1,
device=torch.cuda.current_device(),
dtype=torch.long).fill_(0),
persistent=False)
self.freq_cnter[-1] = sys.maxsize
@ -168,24 +179,26 @@ class CachedParamMgr(torch.nn.Module):
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
"""reorder
reorder the weight according to ids' frequency in dataset before training.
Also Build the IndexMappingTable, aka index_mapping_table.
Execute only once before training.
Execute only once before training, also known as warmup phase.
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
:NOTE If you are use the LFU as the eviction strategy, you can skip this function. The `freq_cnter` will be initialized as all zeros.
You can also call this function to inialized the `freq_cnter` with dataset frequency statistics.
Args:
ids_freq_mapping (List[int]): a list, idx is id number, value is freq. if None no reorder
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
warmup_ratio (float): the amount of chunks preloaded in cuda cache
"""
if ids_freq_mapping is not None:
ids_freq_mapping = torch.tensor(ids_freq_mapping)
if not isinstance(ids_freq_mapping, torch.Tensor):
ids_freq_mapping = torch.tensor(ids_freq_mapping)
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
sorted_idx = torch.argsort(tmp_idx)
self.idx_map.data.copy_(sorted_idx)
#initialize freq_cnter if use LFU
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter[:-1],_ = torch.sort(ids_freq_mapping)
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
# As cuda_cached_weight is very big. You may not have that much available memory!
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
self.freq_cnter[:-1], _ = torch.sort(ids_freq_mapping)
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
if preload_row_num > 0:
with Timer() as timer:
@ -265,7 +278,7 @@ class CachedParamMgr(torch.nn.Module):
with record_function("(zhg) get unique indices"):
cpu_row_idxs_original = self.idx_map.index_select(0, ids)
cpu_row_idxs = torch.unique(cpu_row_idxs_original)
assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"the input indices pull {len(cpu_row_idxs)} chunks, " \
f"which is larger than the presented {self.cuda_row_num}, " \
@ -287,7 +300,7 @@ class CachedParamMgr(torch.nn.Module):
# new ids chunk_offset + offset_in_chunk
with record_function("(zhg) embed idx -> cache chunk id"):
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
# update for LFU.
self._update_freq_cnter(cpu_row_idxs_original)
return gpu_row_idxs
@ -314,7 +327,7 @@ class CachedParamMgr(torch.nn.Module):
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
if self._evict_strategy == EvictionStrategy.DATASET:
# mask method.
# mask method.
# set cached_idx_map[invalid_idxs] to -2.
# so those idxs will be sorted to end, therefore not being chosen as victim
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
@ -322,14 +335,14 @@ class CachedParamMgr(torch.nn.Module):
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
elif self._evict_strategy == EvictionStrategy.LFU:
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
if self.buffer_size > 0: