mirror of https://github.com/hpcaitech/ColossalAI
[FAW] add more docs and fix a warning (#1500)
parent 5a6fd71f90
commit d5085bb317
@@ -7,32 +7,46 @@ from .copyer import LimitBuffIndexCopyer
 from enum import Enum
 import sys


 class EvictionStrategy(Enum):
     LFU = 1
     # dataset aware eviction strategy
     DATASET = 2


 class CachedParamMgr(torch.nn.Module):
     """
-    Manage Embedding Weights in Cache on CPU and CUDA memory.
-    CPU maintains entire original weight.
-    CUDA maintains a fraction of weights used in the upcomming computation.
-    During training, GPU needs to transmit rows between CPU and GPU.
+    Manage Embedding Weights on CPU and CUDA memory using a software cache.
+    CPU maintains the entire original weight.
+    CUDA maintains a fraction of the weights used in the upcoming computation. The number of rows kept in CUDA is controlled by `cuda_row_num`.
+    During training, the GPU needs to transmit embedding rows between CPU and GPU.
+
+    Args:
+        weight (torch.Tensor): the weight of the Embedding layer.
+        cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0.
+        buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000.
+        pin_weight (bool, optional): use pinned memory to store the cpu weight. If set to `True`, cpu memory usage will increase significantly. Defaults to False.
+        evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options:
+            `EvictionStrategy.LFU`: evict the least frequently used rows.
+            `EvictionStrategy.DATASET`: use statistics collected from the target dataset. It usually leads to less cpu-gpu communication volume.
+            Defaults to EvictionStrategy.DATASET.
     """

-    def __init__(self,
-                 weight: torch.Tensor,
-                 cuda_row_num: int = 0,
-                 buffer_size: int = 50_000,
-                 pin_weight=False,
-                 evict_strategy=EvictionStrategy.DATASET,) -> None:
+    def __init__(
+        self,
+        weight: torch.Tensor,
+        cuda_row_num: int = 0,
+        buffer_size: int = 50_000,
+        pin_weight: bool = False,
+        evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
+    ) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
         self.num_embeddings, self.embedding_dim = weight.shape
         self.cuda_row_num = cuda_row_num
         self._cuda_available_row_num = self.cuda_row_num
         self.pin_weight = pin_weight

         self.elem_size_in_byte = weight.element_size()

         # weight configure
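
To make the cache semantics above concrete, here is a minimal usage sketch of the constructor in its new form. The import path, the table sizes, and the availability of a CUDA device are assumptions for illustration, not taken from this diff.

import torch

# hypothetical import path; adjust to wherever CachedParamMgr lives in the repo
from colossalai.nn.parallel.layers.cache_embedding import CachedParamMgr, EvictionStrategy

weight = torch.randn(1_000_000, 64)      # full embedding table, resident on CPU
mgr = CachedParamMgr(
    weight,
    cuda_row_num=20_000,                 # at most 20k rows live on the GPU at once
    buffer_size=50_000,                  # rows per CPU<->GPU transfer buffer
    pin_weight=True,                     # pinned CPU memory speeds up H2D copies
    evict_strategy=EvictionStrategy.LFU,
)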
@@ -50,17 +64,14 @@ class CachedParamMgr(torch.nn.Module):
         if self._evict_strategy == EvictionStrategy.LFU:
             # cpu_row_idx -> frequency, freq of the cpu rows.
             # evict the minimal freq value row in cuda cache.

             '''
-            during cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we re-map that element to -1 temporary.
-            also, disabled cached_idx_map element maps to -1 by default.
-            freq_cnter[-1], the last element, should ALWAYS be MAX VALUE so those masked or disabled idxs will be argsorted to end,
-            not being chosen to evict.
-
-            ZH: the last element of freq_cnter is set to the maximum value; cache idxs that must not be selected for swap-out are all -1 and point to this maximum, so they sort to the end of the queue and are never chosen for eviction.
+            The last element of `freq_cnter` is set to the maximum int value (`sys.maxsize`).
+            Rows of `self.cuda_weight` that hold nothing (unused) carry the value -1 in `self.cached_idx_map`, so they index this last element.
+            In this way, the unused rows are placed at the end of the sort order and are never chosen for eviction.
             '''
             self.register_buffer("freq_cnter",
-                                 torch.empty(self.num_embeddings + 1, device=torch.cuda.current_device(),
+                                 torch.empty(self.num_embeddings + 1,
+                                             device=torch.cuda.current_device(),
                                              dtype=torch.long).fill_(0),
                                  persistent=False)
             self.freq_cnter[-1] = sys.maxsize
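
The sentinel trick in this hunk can be reproduced standalone. A small sketch with made-up values, assuming nothing beyond torch and sys:

import sys
import torch

# freq_cnter has num_embeddings + 1 slots; the extra last slot is the sentinel
freq_cnter = torch.tensor([5, 1, 9, 3, sys.maxsize])

# cached_idx_map: for each cuda row, which cpu row it holds; -1 means "empty slot"
# or "temporarily masked", and indexing with -1 hits the sentinel at the end
cached_idx_map = torch.tensor([2, -1, 0, 3])

freqs = freq_cnter[cached_idx_map]   # tensor([9, maxsize, 5, 3])
victims = torch.argsort(freqs)       # least-frequent first; sentinels sort last
print(victims)                       # tensor([3, 2, 0, 1]): the empty cuda row 1 comes last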
@@ -168,24 +179,26 @@ class CachedParamMgr(torch.nn.Module):
     def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
         """reorder
         reorder the weight according to ids' frequency in dataset before training.
         Also Build the IndexMappingTable, aka index_mapping_table.
-        Execute only once before training.
+        Execute only once before training, also known as the warmup phase.
+
+        :NOTE If you would like to use DATASET as the eviction strategy, you must call this function.
+        :NOTE If you use LFU as the eviction strategy, you can skip this function; `freq_cnter` will be initialized as all zeros.
+        You can also call this function to initialize `freq_cnter` with dataset frequency statistics.

         Args:
-            ids_freq_mapping (List[int]): a list, idx is id number, value is freq. if None no reorder
+            ids_freq_mapping (List[int]): a list whose index is an id and whose value is that id's frequency. If None, the cpu weight is not reordered.
             warmup_ratio (float): the amount of chunks preloaded in cuda cache
         """
         if ids_freq_mapping is not None:
-            ids_freq_mapping = torch.tensor(ids_freq_mapping)
+            if not isinstance(ids_freq_mapping, torch.Tensor):
+                ids_freq_mapping = torch.tensor(ids_freq_mapping)
             tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
             sorted_idx = torch.argsort(tmp_idx)
             self.idx_map.data.copy_(sorted_idx)

         #initialize freq_cnter if use LFU
         if self._evict_strategy == EvictionStrategy.LFU:
-            self.freq_cnter[:-1],_ = torch.sort(ids_freq_mapping)
-        # TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
-        # As cuda_cached_weight is very big. You may not have that much available memory!
-        # Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
+            self.freq_cnter[:-1], _ = torch.sort(ids_freq_mapping)
+
         preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
         if preload_row_num > 0:
             with Timer() as timer:
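
The double argsort in `reorder` is the core of the warmup: it converts raw frequencies into ranks, so the hottest id is remapped to cpu row 0 and the coldest to the last row. A toy run with made-up frequencies:

import torch

ids_freq_mapping = torch.tensor([10, 500, 3, 42])           # index = id, value = frequency
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)  # ids ordered by freq: tensor([1, 3, 0, 2])
sorted_idx = torch.argsort(tmp_idx)                         # rank of each id:     tensor([2, 0, 3, 1])

# id 1 (freq 500) maps to row 0, id 3 (freq 42) to row 1, and so on; low row
# indices are the hot rows that the warmup phase preloads into the cuda cache.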
@@ -265,7 +278,7 @@ class CachedParamMgr(torch.nn.Module):
         with record_function("(zhg) get unique indices"):
             cpu_row_idxs_original = self.idx_map.index_select(0, ids)
             cpu_row_idxs = torch.unique(cpu_row_idxs_original)

         assert len(cpu_row_idxs) <= self.cuda_row_num, \
             f"the input indices pull {len(cpu_row_idxs)} chunks, " \
             f"which is larger than the presented {self.cuda_row_num}, " \
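
The dedup step in this hunk matters because a batch usually repeats ids, but each distinct cpu row only needs to be pulled into the cuda cache once. A toy illustration with made-up tensors:

import torch

idx_map = torch.tensor([2, 0, 3, 1])                   # id -> cpu row (e.g. after reorder)
ids = torch.tensor([1, 1, 3, 0, 1])                    # raw ids in the incoming batch
cpu_row_idxs_original = idx_map.index_select(0, ids)   # tensor([0, 0, 1, 2, 0])
cpu_row_idxs = torch.unique(cpu_row_idxs_original)     # tensor([0, 1, 2])
# the assert then checks that these distinct rows all fit into cuda_row_num slots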
@@ -287,7 +300,7 @@ class CachedParamMgr(torch.nn.Module):
         # new ids chunk_offset + offset_in_chunk
         with record_function("(zhg) embed idx -> cache chunk id"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)

         # update for LFU.
         self._update_freq_cnter(cpu_row_idxs_original)
         return gpu_row_idxs
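
Note that `_update_freq_cnter` is called on the original, non-deduplicated indices, so an id repeated in the batch counts multiple times. Its body is not part of this diff; the following is only a plausible sketch of such an LFU bookkeeping step, not the repository's implementation:

import torch

def update_freq_cnter(freq_cnter: torch.Tensor, cpu_row_idxs_original: torch.Tensor) -> None:
    # count how many times each cpu row is touched by this batch, duplicates included
    rows, counts = torch.unique(cpu_row_idxs_original, return_counts=True)
    # accumulate into the per-row counters; the sentinel at freq_cnter[-1] is never indexed
    freq_cnter[rows] += counts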
@@ -314,7 +327,7 @@ class CachedParamMgr(torch.nn.Module):
         mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)

         if self._evict_strategy == EvictionStrategy.DATASET:
             # mask method.
             # set cached_idx_map[invalid_idxs] to -2.
             # so those idxs will be sorted to end, therefore not being chosen as victim
             invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
@@ -322,14 +335,14 @@ class CachedParamMgr(torch.nn.Module):
             self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
             evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
             self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)

         elif self._evict_strategy == EvictionStrategy.LFU:
             invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
             backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
             self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
             evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
             self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)

         evict_info = self.cached_idx_map[evict_gpu_row_idxs]

         if self.buffer_size > 0:
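
Both branches above share a mask-fill-restore pattern: rows the current batch still needs (`evict_backlist`) are temporarily pointed at a sentinel so the victim search cannot select them, then the real mapping is restored. A self-contained rerun of the LFU branch with made-up values:

import torch

cached_idx_map = torch.tensor([7, 4, 9, 2])        # cuda row -> cpu row
evict_backlist = torch.tensor([4, 2])              # cpu rows needed by the current batch

mask = torch.isin(cached_idx_map, evict_backlist)  # tensor([False, True, False, True])
invalid_idxs = torch.nonzero(mask).squeeze(1)      # cuda rows we must not evict
backup_idxs = cached_idx_map[mask].clone()

cached_idx_map.index_fill_(0, invalid_idxs, -1)    # point them at the freq_cnter sentinel
# ... pick evict_num victims here, e.g. by the freq_cnter argsort shown earlier ...
cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)   # restore the real mapping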