[FAW] add more docs and fix a warning (#1500)

pull/1501/head^2
Jiarui Fang 2022-08-26 14:10:21 +08:00 committed by GitHub
parent 5a6fd71f90
commit d5085bb317
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 45 additions and 32 deletions

View File

@ -7,32 +7,46 @@ from .copyer import LimitBuffIndexCopyer
from enum import Enum from enum import Enum
import sys import sys
class EvictionStrategy(Enum): class EvictionStrategy(Enum):
LFU = 1 LFU = 1
# dataset aware eviction strategy
DATASET = 2 DATASET = 2
class CachedParamMgr(torch.nn.Module): class CachedParamMgr(torch.nn.Module):
""" """
Manage Embedding Weights in Cache on CPU and CUDA memory. Manage Embedding Weights on CPU and CUDA memory uses a software cache.
CPU maintains entire original weight. CPU maintains the entire original weight.
CUDA maintains a fraction of weights used in the upcomming computation. CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`.
During training, GPU needs to transmit rows between CPU and GPU. During training, GPU needs to transmit embedding rows between CPU and GPU.
Args:
weight (torch.Tensor): the weight of the Embedding layer.
cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0.
buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000.
pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase largely. Defaults to False.
evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options.
`EvictionStrategy.LFU`: use the least frequently used cache.
`EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
Defaults to EvictionStrategy.DATASET.
""" """
def __init__(self, def __init__(
weight: torch.Tensor, self,
cuda_row_num: int = 0, weight: torch.Tensor,
buffer_size: int = 50_000, cuda_row_num: int = 0,
pin_weight=False, buffer_size: int = 50_000,
evict_strategy=EvictionStrategy.DATASET,) -> None: pin_weight: bool = False,
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
) -> None:
super(CachedParamMgr, self).__init__() super(CachedParamMgr, self).__init__()
self.buffer_size = buffer_size self.buffer_size = buffer_size
self.num_embeddings, self.embedding_dim = weight.shape self.num_embeddings, self.embedding_dim = weight.shape
self.cuda_row_num = cuda_row_num self.cuda_row_num = cuda_row_num
self._cuda_available_row_num = self.cuda_row_num self._cuda_available_row_num = self.cuda_row_num
self.pin_weight = pin_weight self.pin_weight = pin_weight
self.elem_size_in_byte = weight.element_size() self.elem_size_in_byte = weight.element_size()
# weight configure # weight configure
@ -50,17 +64,14 @@ class CachedParamMgr(torch.nn.Module):
if self._evict_strategy == EvictionStrategy.LFU: if self._evict_strategy == EvictionStrategy.LFU:
# cpu_row_idx -> frequency, freq of the cpu rows. # cpu_row_idx -> frequency, freq of the cpu rows.
# evict the minimal freq value row in cuda cache. # evict the minimal freq value row in cuda cache.
''' '''
during cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we re-map that element to -1 temporary. The last element of `freq_cnter` is set to the maximum value of int.
also, disabled cached_idx_map element maps to -1 by default. The rows store nothing (not used) in the `self.cuda_weight` whose value is -1 in `self.cached_idx_map`.
freq_cnter[-1], the last element, should ALWAYS be MAX VALUE so those masked or disabled idxs will be argsorted to end, In this way, the not used rows are placed at the end of the sorted.
not being chosen to evict.
ZH: freq_cnter的最后一位设为了最大值, 不该被选为换出的cache idx都是-1, 指向这个最大值, 所以排序时在队尾, 不会被选中换出
''' '''
self.register_buffer("freq_cnter", self.register_buffer("freq_cnter",
torch.empty(self.num_embeddings + 1, device=torch.cuda.current_device(), torch.empty(self.num_embeddings + 1,
device=torch.cuda.current_device(),
dtype=torch.long).fill_(0), dtype=torch.long).fill_(0),
persistent=False) persistent=False)
self.freq_cnter[-1] = sys.maxsize self.freq_cnter[-1] = sys.maxsize
@ -168,24 +179,26 @@ class CachedParamMgr(torch.nn.Module):
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7): def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
"""reorder """reorder
reorder the weight according to ids' frequency in dataset before training. reorder the weight according to ids' frequency in dataset before training.
Also Build the IndexMappingTable, aka index_mapping_table. Execute only once before training, also known as warmup phase.
Execute only once before training.
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
:NOTE If you are use the LFU as the eviction strategy, you can skip this function. The `freq_cnter` will be initialized as all zeros.
You can also call this function to inialized the `freq_cnter` with dataset frequency statistics.
Args: Args:
ids_freq_mapping (List[int]): a list, idx is id number, value is freq. if None no reorder ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
warmup_ratio (float): the amount of chunks preloaded in cuda cache warmup_ratio (float): the amount of chunks preloaded in cuda cache
""" """
if ids_freq_mapping is not None: if ids_freq_mapping is not None:
ids_freq_mapping = torch.tensor(ids_freq_mapping) if not isinstance(ids_freq_mapping, torch.Tensor):
ids_freq_mapping = torch.tensor(ids_freq_mapping)
tmp_idx = torch.argsort(ids_freq_mapping, descending=True) tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
sorted_idx = torch.argsort(tmp_idx) sorted_idx = torch.argsort(tmp_idx)
self.idx_map.data.copy_(sorted_idx) self.idx_map.data.copy_(sorted_idx)
#initialize freq_cnter if use LFU #initialize freq_cnter if use LFU
if self._evict_strategy == EvictionStrategy.LFU: if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter[:-1],_ = torch.sort(ids_freq_mapping) self.freq_cnter[:-1], _ = torch.sort(ids_freq_mapping)
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
# As cuda_cached_weight is very big. You may not have that much available memory!
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings) preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
if preload_row_num > 0: if preload_row_num > 0:
with Timer() as timer: with Timer() as timer:
@ -265,7 +278,7 @@ class CachedParamMgr(torch.nn.Module):
with record_function("(zhg) get unique indices"): with record_function("(zhg) get unique indices"):
cpu_row_idxs_original = self.idx_map.index_select(0, ids) cpu_row_idxs_original = self.idx_map.index_select(0, ids)
cpu_row_idxs = torch.unique(cpu_row_idxs_original) cpu_row_idxs = torch.unique(cpu_row_idxs_original)
assert len(cpu_row_idxs) <= self.cuda_row_num, \ assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"the input indices pull {len(cpu_row_idxs)} chunks, " \ f"the input indices pull {len(cpu_row_idxs)} chunks, " \
f"which is larger than the presented {self.cuda_row_num}, " \ f"which is larger than the presented {self.cuda_row_num}, " \
@ -287,7 +300,7 @@ class CachedParamMgr(torch.nn.Module):
# new ids chunk_offset + offset_in_chunk # new ids chunk_offset + offset_in_chunk
with record_function("(zhg) embed idx -> cache chunk id"): with record_function("(zhg) embed idx -> cache chunk id"):
gpu_row_idxs = self._id_to_cached_cuda_id(ids) gpu_row_idxs = self._id_to_cached_cuda_id(ids)
# update for LFU. # update for LFU.
self._update_freq_cnter(cpu_row_idxs_original) self._update_freq_cnter(cpu_row_idxs_original)
return gpu_row_idxs return gpu_row_idxs
@ -314,7 +327,7 @@ class CachedParamMgr(torch.nn.Module):
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist) mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
if self._evict_strategy == EvictionStrategy.DATASET: if self._evict_strategy == EvictionStrategy.DATASET:
# mask method. # mask method.
# set cached_idx_map[invalid_idxs] to -2. # set cached_idx_map[invalid_idxs] to -2.
# so those idxs will be sorted to end, therefore not being chosen as victim # so those idxs will be sorted to end, therefore not being chosen as victim
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
@ -322,14 +335,14 @@ class CachedParamMgr(torch.nn.Module):
self.cached_idx_map.index_fill_(0, invalid_idxs, -2) self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
elif self._evict_strategy == EvictionStrategy.LFU: elif self._evict_strategy == EvictionStrategy.LFU:
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1) invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone() backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
self.cached_idx_map.index_fill_(0, invalid_idxs, -1) self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num) evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs) self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
evict_info = self.cached_idx_map[evict_gpu_row_idxs] evict_info = self.cached_idx_map[evict_gpu_row_idxs]
if self.buffer_size > 0: if self.buffer_size > 0: