mirror of https://github.com/hpcaitech/ColossalAI
[FAW] add more docs and fix a warning (#1500)
parent
5a6fd71f90
commit
d5085bb317
|
@ -7,32 +7,46 @@ from .copyer import LimitBuffIndexCopyer
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
|
||||||
class EvictionStrategy(Enum):
|
class EvictionStrategy(Enum):
|
||||||
LFU = 1
|
LFU = 1
|
||||||
|
# dataset aware eviction strategy
|
||||||
DATASET = 2
|
DATASET = 2
|
||||||
|
|
||||||
|
|
||||||
class CachedParamMgr(torch.nn.Module):
|
class CachedParamMgr(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Manage Embedding Weights in Cache on CPU and CUDA memory.
|
Manage Embedding Weights on CPU and CUDA memory uses a software cache.
|
||||||
CPU maintains entire original weight.
|
CPU maintains the entire original weight.
|
||||||
CUDA maintains a fraction of weights used in the upcomming computation.
|
CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`.
|
||||||
During training, GPU needs to transmit rows between CPU and GPU.
|
During training, GPU needs to transmit embedding rows between CPU and GPU.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight (torch.Tensor): the weight of the Embedding layer.
|
||||||
|
cuda_row_num (int, optional): the number of rows cached in CUDA memory. Defaults to 0.
|
||||||
|
buffer_size (int, optional): the number of rows in a data transmitter buffer. Defaults to 50_000.
|
||||||
|
pin_weight (bool, optional): use pin memory to store the cpu weight. If set `True`, the cpu memory usage will increase largely. Defaults to False.
|
||||||
|
evict_strategy (EvictionStrategy, optional): the eviction strategy. There are two options.
|
||||||
|
`EvictionStrategy.LFU`: use the least frequently used cache.
|
||||||
|
`EvictionStrategy.DATASET`: use the stats collected from the target dataset. It usually leads to less cpu-gpu communication volume.
|
||||||
|
Defaults to EvictionStrategy.DATASET.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
weight: torch.Tensor,
|
self,
|
||||||
cuda_row_num: int = 0,
|
weight: torch.Tensor,
|
||||||
buffer_size: int = 50_000,
|
cuda_row_num: int = 0,
|
||||||
pin_weight=False,
|
buffer_size: int = 50_000,
|
||||||
evict_strategy=EvictionStrategy.DATASET,) -> None:
|
pin_weight: bool = False,
|
||||||
|
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET,
|
||||||
|
) -> None:
|
||||||
super(CachedParamMgr, self).__init__()
|
super(CachedParamMgr, self).__init__()
|
||||||
self.buffer_size = buffer_size
|
self.buffer_size = buffer_size
|
||||||
self.num_embeddings, self.embedding_dim = weight.shape
|
self.num_embeddings, self.embedding_dim = weight.shape
|
||||||
self.cuda_row_num = cuda_row_num
|
self.cuda_row_num = cuda_row_num
|
||||||
self._cuda_available_row_num = self.cuda_row_num
|
self._cuda_available_row_num = self.cuda_row_num
|
||||||
self.pin_weight = pin_weight
|
self.pin_weight = pin_weight
|
||||||
|
|
||||||
self.elem_size_in_byte = weight.element_size()
|
self.elem_size_in_byte = weight.element_size()
|
||||||
|
|
||||||
# weight configure
|
# weight configure
|
||||||
|
@ -50,17 +64,14 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
if self._evict_strategy == EvictionStrategy.LFU:
|
if self._evict_strategy == EvictionStrategy.LFU:
|
||||||
# cpu_row_idx -> frequency, freq of the cpu rows.
|
# cpu_row_idx -> frequency, freq of the cpu rows.
|
||||||
# evict the minimal freq value row in cuda cache.
|
# evict the minimal freq value row in cuda cache.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
during cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we re-map that element to -1 temporary.
|
The last element of `freq_cnter` is set to the maximum value of int.
|
||||||
also, disabled cached_idx_map element maps to -1 by default.
|
The rows store nothing (not used) in the `self.cuda_weight` whose value is -1 in `self.cached_idx_map`.
|
||||||
freq_cnter[-1], the last element, should ALWAYS be MAX VALUE so those masked or disabled idxs will be argsorted to end,
|
In this way, the not used rows are placed at the end of the sorted.
|
||||||
not being chosen to evict.
|
|
||||||
|
|
||||||
ZH: freq_cnter的最后一位设为了最大值, 不该被选为换出的cache idx都是-1, 指向这个最大值, 所以排序时在队尾, 不会被选中换出
|
|
||||||
'''
|
'''
|
||||||
self.register_buffer("freq_cnter",
|
self.register_buffer("freq_cnter",
|
||||||
torch.empty(self.num_embeddings + 1, device=torch.cuda.current_device(),
|
torch.empty(self.num_embeddings + 1,
|
||||||
|
device=torch.cuda.current_device(),
|
||||||
dtype=torch.long).fill_(0),
|
dtype=torch.long).fill_(0),
|
||||||
persistent=False)
|
persistent=False)
|
||||||
self.freq_cnter[-1] = sys.maxsize
|
self.freq_cnter[-1] = sys.maxsize
|
||||||
|
@ -168,24 +179,26 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
|
def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7):
|
||||||
"""reorder
|
"""reorder
|
||||||
reorder the weight according to ids' frequency in dataset before training.
|
reorder the weight according to ids' frequency in dataset before training.
|
||||||
Also Build the IndexMappingTable, aka index_mapping_table.
|
Execute only once before training, also known as warmup phase.
|
||||||
Execute only once before training.
|
|
||||||
|
:NOTE If you would like to use the DATASET as the eviction strategy, you must call this function.
|
||||||
|
:NOTE If you are use the LFU as the eviction strategy, you can skip this function. The `freq_cnter` will be initialized as all zeros.
|
||||||
|
You can also call this function to inialized the `freq_cnter` with dataset frequency statistics.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ids_freq_mapping (List[int]): a list, idx is id number, value is freq. if None no reorder
|
ids_freq_mapping (List[int]): a list, whose offset is id number, value is freq. if None then not reorder the cpu weight.
|
||||||
warmup_ratio (float): the amount of chunks preloaded in cuda cache
|
warmup_ratio (float): the amount of chunks preloaded in cuda cache
|
||||||
"""
|
"""
|
||||||
if ids_freq_mapping is not None:
|
if ids_freq_mapping is not None:
|
||||||
ids_freq_mapping = torch.tensor(ids_freq_mapping)
|
if not isinstance(ids_freq_mapping, torch.Tensor):
|
||||||
|
ids_freq_mapping = torch.tensor(ids_freq_mapping)
|
||||||
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
|
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
|
||||||
sorted_idx = torch.argsort(tmp_idx)
|
sorted_idx = torch.argsort(tmp_idx)
|
||||||
self.idx_map.data.copy_(sorted_idx)
|
self.idx_map.data.copy_(sorted_idx)
|
||||||
#initialize freq_cnter if use LFU
|
#initialize freq_cnter if use LFU
|
||||||
if self._evict_strategy == EvictionStrategy.LFU:
|
if self._evict_strategy == EvictionStrategy.LFU:
|
||||||
self.freq_cnter[:-1],_ = torch.sort(ids_freq_mapping)
|
self.freq_cnter[:-1], _ = torch.sort(ids_freq_mapping)
|
||||||
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
|
|
||||||
# As cuda_cached_weight is very big. You may not have that much available memory!
|
|
||||||
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
|
|
||||||
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
|
preload_row_num = min(int(np.ceil(self.cuda_row_num * warmup_ratio)), self.num_embeddings)
|
||||||
if preload_row_num > 0:
|
if preload_row_num > 0:
|
||||||
with Timer() as timer:
|
with Timer() as timer:
|
||||||
|
@ -265,7 +278,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
with record_function("(zhg) get unique indices"):
|
with record_function("(zhg) get unique indices"):
|
||||||
cpu_row_idxs_original = self.idx_map.index_select(0, ids)
|
cpu_row_idxs_original = self.idx_map.index_select(0, ids)
|
||||||
cpu_row_idxs = torch.unique(cpu_row_idxs_original)
|
cpu_row_idxs = torch.unique(cpu_row_idxs_original)
|
||||||
|
|
||||||
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
||||||
f"the input indices pull {len(cpu_row_idxs)} chunks, " \
|
f"the input indices pull {len(cpu_row_idxs)} chunks, " \
|
||||||
f"which is larger than the presented {self.cuda_row_num}, " \
|
f"which is larger than the presented {self.cuda_row_num}, " \
|
||||||
|
@ -287,7 +300,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
# new ids chunk_offset + offset_in_chunk
|
# new ids chunk_offset + offset_in_chunk
|
||||||
with record_function("(zhg) embed idx -> cache chunk id"):
|
with record_function("(zhg) embed idx -> cache chunk id"):
|
||||||
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
|
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
|
||||||
|
|
||||||
# update for LFU.
|
# update for LFU.
|
||||||
self._update_freq_cnter(cpu_row_idxs_original)
|
self._update_freq_cnter(cpu_row_idxs_original)
|
||||||
return gpu_row_idxs
|
return gpu_row_idxs
|
||||||
|
@ -314,7 +327,7 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
||||||
|
|
||||||
if self._evict_strategy == EvictionStrategy.DATASET:
|
if self._evict_strategy == EvictionStrategy.DATASET:
|
||||||
# mask method.
|
# mask method.
|
||||||
# set cached_idx_map[invalid_idxs] to -2.
|
# set cached_idx_map[invalid_idxs] to -2.
|
||||||
# so those idxs will be sorted to end, therefore not being chosen as victim
|
# so those idxs will be sorted to end, therefore not being chosen as victim
|
||||||
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
||||||
|
@ -322,14 +335,14 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
||||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||||
|
|
||||||
elif self._evict_strategy == EvictionStrategy.LFU:
|
elif self._evict_strategy == EvictionStrategy.LFU:
|
||||||
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
||||||
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
||||||
self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
|
||||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||||
|
|
||||||
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
||||||
|
|
||||||
if self.buffer_size > 0:
|
if self.buffer_size > 0:
|
||||||
|
|
Loading…
Reference in New Issue