mirror of https://github.com/hpcaitech/ColossalAI
[FAW] FAW embedding uses LFU as eviction strategy, initialized with dataset stats (#1494)
parent 8b7d6bd5be
commit 0ed2f46131
@@ -5,7 +5,8 @@ from typing import List, Optional
 from contexttimer import Timer
 from .copyer import LimitBuffIndexCopyer
 from enum import Enum

+import sys

 class EvictionStrategy(Enum):
     LFU = 1
@@ -25,7 +25,7 @@ class CachedParamMgr(torch.nn.Module):
                  cuda_row_num: int = 0,
                  buffer_size: int = 50_000,
                  pin_weight=False,
-                 evict_strategy=EvictionStrategy.DATASET) -> None:
+                 evict_strategy=EvictionStrategy.DATASET,) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
         self.num_embeddings, self.embedding_dim = weight.shape
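Note: the constructor now takes an evict_strategy flag. A minimal usage sketch; the import path and the table sizes are hypothetical, not part of this diff:

import torch
from cache_mgr import CachedParamMgr, EvictionStrategy   # hypothetical import path; adjust to the package layout

weight = torch.randn(1000, 16)                            # hypothetical CPU-side embedding table
mgr = CachedParamMgr(weight,
                     cuda_row_num=128,
                     buffer_size=50_000,
                     pin_weight=False,
                     evict_strategy=EvictionStrategy.LFU) # opt in to LFU eviction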
@@ -50,12 +50,22 @@ class CachedParamMgr(torch.nn.Module):
         if self._evict_strategy == EvictionStrategy.LFU:
             # cpu_row_idx -> frequency, freq of the cpu rows.
             # evict the minimal freq value row in cuda cache.
+            '''
+            During cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we temporarily re-map
+            that element to -1; disabled cached_idx_map elements also map to -1 by default.
+            freq_cnter[-1], the last element, must ALWAYS hold the max value, so masked or disabled idxs are
+            argsorted to the end and never chosen for eviction.
+
+            (ZH, translated): the last slot of freq_cnter is set to the max value; cache idxs that must not be
+            evicted all map to -1 and point at this max value, so they sort to the tail and are never selected.
+            '''
             self.register_buffer("freq_cnter",
-                                 torch.empty(self.num_embeddings, device=torch.cuda.current_device(),
+                                 torch.empty(self.num_embeddings + 1, device=torch.cuda.current_device(),
                                              dtype=torch.long).fill_(0),
                                  persistent=False)
+            self.freq_cnter[-1] = sys.maxsize

-    def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None:
+    def _update_freq_cnter(self, cpu_row_idxs_original: torch.Tensor) -> None:
         """_update_freq_cnter

         Update the frequency value w.r.t. the given cpu row indices in self.freq_cnter.
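Note: the sentinel trick above works because PyTorch advanced indexing accepts negative values in the index tensor, so an index of -1 selects the last element, which is pinned to sys.maxsize. A minimal sketch with toy numbers of why masked slots never win the argsort:

import sys
import torch

freq_cnter = torch.tensor([3, 8, 1, sys.maxsize])   # last slot is the sentinel
cached_idx_map = torch.tensor([1, -1, 2, 0])         # -1 marks a masked/disabled slot

freqs = freq_cnter[cached_idx_map]                   # -1 indexes the sentinel -> sys.maxsize
victim = torch.argsort(freqs)[0]                     # least frequent wins; sentinel sorts last
print(victim.item())                                 # 2: the slot holding row 2 (freq 1)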
@@ -64,7 +74,8 @@ class CachedParamMgr(torch.nn.Module):
             cpu_row_idxs_original (torch.Tensor): a list of indices of cpu weight.
         """
         if self._evict_strategy == EvictionStrategy.LFU:
-            self.freq_cnter[cpu_row_idxs] += 1
+            add_num = torch.bincount(cpu_row_idxs_original)
+            self.freq_cnter[:add_num.shape[0]] += add_num

     def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
         """_find_evict_gpu_idxs
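Note: the switch to bincount matters because an indexed += collapses duplicate indices, while an LFU counter must credit every occurrence. This is also why the call site below passes cpu_row_idxs_original (with duplicates) instead of the uniqued tensor. A toy comparison:

import torch

ids = torch.tensor([2, 2, 5, 2])        # id 2 occurs three times in the batch
cnt = torch.zeros(8, dtype=torch.long)

cnt[ids] += 1                            # duplicate writes collapse: cnt[2] == 1
assert cnt[2].item() == 1

cnt.zero_()
add_num = torch.bincount(ids)            # true per-id occurrence counts
cnt[:add_num.shape[0]] += add_num        # cnt[2] == 3
assert cnt[2].item() == 3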
@@ -165,10 +176,13 @@ class CachedParamMgr(torch.nn.Module):
             warmup_ratio (float): the amount of chunks preloaded in cuda cache
         """
         if ids_freq_mapping is not None:
+            ids_freq_mapping = torch.tensor(ids_freq_mapping)
             tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
             sorted_idx = torch.argsort(tmp_idx)
             self.idx_map.data.copy_(sorted_idx)

+        # initialize freq_cnter if using LFU
+        if self._evict_strategy == EvictionStrategy.LFU:
+            self.freq_cnter[:-1], _ = torch.sort(ids_freq_mapping)
+
         # TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
         # As cuda_cached_weight is very big. You may not have that much available memory!
         # Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
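Note: argsort applied twice turns frequencies into ranks: the first argsort orders ids by descending frequency, the second recovers, for each id, its position in that order, which becomes its internal row. A toy run with hypothetical frequencies:

import torch

freq = torch.tensor([5, 1, 9, 3])                 # hypothetical per-id dataset frequencies
tmp_idx = torch.argsort(freq, descending=True)    # ids by hotness: [2, 0, 3, 1]
rank = torch.argsort(tmp_idx)                     # each id's rank:  [1, 3, 0, 2]
print(rank.tolist())                              # id 2 (freq 9) maps to internal row 0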
@@ -249,7 +263,8 @@ class CachedParamMgr(torch.nn.Module):
             torch.Tensor: indices on the cuda_cached_weight.
         """
         with record_function("(zhg) get unique indices"):
-            cpu_row_idxs = torch.unique(self.idx_map.index_select(0, ids))
+            cpu_row_idxs_original = self.idx_map.index_select(0, ids)
+            cpu_row_idxs = torch.unique(cpu_row_idxs_original)

         assert len(cpu_row_idxs) <= self.cuda_row_num, \
             f"the input indices pull {len(cpu_row_idxs)} chunks, " \
@@ -274,8 +289,7 @@ class CachedParamMgr(torch.nn.Module):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)

         # update for LFU.
-        self._update_freq_cnter(cpu_row_idxs)
-
+        self._update_freq_cnter(cpu_row_idxs_original)
         return gpu_row_idxs

     def _reset_comm_stats(self):
@@ -299,25 +313,22 @@ class CachedParamMgr(torch.nn.Module):
             with Timer() as timer:
                 mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
+                invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
+
                 if self._evict_strategy == EvictionStrategy.DATASET:
                     # mask method.
                     # set cached_idx_map[invalid_idxs] to -2.
                     # so those idxs will be sorted to end, therefore not being chosen as victim
-                    invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                     backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
                     self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
                     self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)

                 elif self._evict_strategy == EvictionStrategy.LFU:
-                    # another mask method.
-                    # set freq_cnter[invalid_idxs] to max
-                    # so those idxs will be sorted to end, therefore not being chosen as victim
-                    invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
-                    backup_cnter = self.freq_cnter[invalid_idxs].clone()
-                    self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1)    # or can we use a confident max value?
+                    backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
+                    self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
-                    self.freq_cnter.index_copy_(0, invalid_idxs, backup_cnter)
+                    self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)

                 evict_info = self.cached_idx_map[evict_gpu_row_idxs]
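Note: both branches follow the same backup / mask-fill / select / restore pattern; only the sentinel differs (-2 for DATASET, the -1-to-maxsize remap for LFU). A standalone sketch of the DATASET variant, with toy state and a plain argsort standing in for _find_evict_gpu_idxs:

import torch

cached_idx_map = torch.tensor([10, 3, 7, 42, 5, 9])   # cpu row held by each cuda slot
evict_backlist = torch.tensor([7, 5])                 # rows that must stay resident

mask = torch.isin(cached_idx_map, evict_backlist)
invalid_idxs = torch.nonzero(mask).squeeze(1)

backup_idxs = cached_idx_map[mask].clone()
cached_idx_map.index_fill_(0, invalid_idxs, -2)       # -2 sorts to the end, never a victim
victims = torch.argsort(cached_idx_map, descending=True)[:2]
cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)

print(victims.tolist())                               # [3, 0]: slots holding rows 42 and 10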
@@ -6,7 +6,7 @@ from .freq_aware_embedding import FreqAwareEmbeddingBag
 from colossalai.nn._ops._utils import dual_all_to_all

 from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
-from .cache_mgr import CachedParamMgr
+from .cache_mgr import CachedParamMgr, EvictionStrategy

 def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
     if world_size == 1:
@@ -48,6 +48,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                  warmup_ratio=0.7,
                  buffer_size=50_000,
                  pin_weight=False,
+                 evict_strategy: EvictionStrategy = EvictionStrategy.DATASET
                  ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -59,7 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
         super(ParallelFreqAwareEmbeddingBag,
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
-                             warmup_ratio, buffer_size, pin_weight)
+                             warmup_ratio, buffer_size, pin_weight, evict_strategy)

     def _weight_alloc(self, dtype, device):
         weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
@@ -159,6 +159,9 @@ def test_lfu_strategy():
     offsets = torch.tensor([0],device="cuda:0")

     # prepare frequency learning info:
     Bag.forward(torch.tensor([2],device="cuda:0"),offsets)
     Bag.forward(torch.tensor([1,2],device="cuda:0"),offsets)
     Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
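Note: replaying the forwards above, the frequencies the LFU counter should accumulate are easy to check by hand:

import torch

batches = [[2], [1, 2], [0, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]
freq = torch.bincount(torch.tensor(sum(batches, [])))
print(freq.tolist())   # [4, 4, 6]: id 2 is the hottest and the last eviction candidate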
@@ -273,6 +276,6 @@ def test_parallel_freq_aware_embed(world_size):


 if __name__ == '__main__':
-    test_freq_aware_embed(True)
+    # test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)
-    # test_lfu_strategy()
+    test_lfu_strategy()