[FAW] FAW embedding uses LFU as eviction strategy, initialized with dataset stats (#1494)

pull/1500/head
CsRic 2022-08-26 11:24:12 +08:00 committed by GitHub
parent 8b7d6bd5be
commit 0ed2f46131
3 changed files with 40 additions and 25 deletions


@@ -5,7 +5,7 @@ from typing import List, Optional
 from contexttimer import Timer
 from .copyer import LimitBuffIndexCopyer
 from enum import Enum
+import sys
 
 class EvictionStrategy(Enum):
     LFU = 1
@ -25,14 +25,14 @@ class CachedParamMgr(torch.nn.Module):
cuda_row_num: int = 0, cuda_row_num: int = 0,
buffer_size: int = 50_000, buffer_size: int = 50_000,
pin_weight=False, pin_weight=False,
evict_strategy=EvictionStrategy.DATASET) -> None: evict_strategy=EvictionStrategy.DATASET,) -> None:
super(CachedParamMgr, self).__init__() super(CachedParamMgr, self).__init__()
self.buffer_size = buffer_size self.buffer_size = buffer_size
self.num_embeddings, self.embedding_dim = weight.shape self.num_embeddings, self.embedding_dim = weight.shape
self.cuda_row_num = cuda_row_num self.cuda_row_num = cuda_row_num
self._cuda_available_row_num = self.cuda_row_num self._cuda_available_row_num = self.cuda_row_num
self.pin_weight = pin_weight self.pin_weight = pin_weight
self.elem_size_in_byte = weight.element_size() self.elem_size_in_byte = weight.element_size()
# weight configure # weight configure
@@ -50,12 +50,22 @@ class CachedParamMgr(torch.nn.Module):
         if self._evict_strategy == EvictionStrategy.LFU:
             # cpu_row_idx -> frequency, freq of the cpu rows.
             # evict the minimal freq value row in cuda cache.
+            '''
+            During cache eviction, if a cached_idx_map element maps to a masked cpu_idx, that element is
+            temporarily re-mapped to -1; disabled cached_idx_map elements also map to -1 by default.
+            freq_cnter[-1], the extra last element, must ALWAYS hold the maximum value, so masked or
+            disabled idxs are argsorted to the end and never chosen for eviction.
+            '''
             self.register_buffer("freq_cnter",
-                                 torch.empty(self.num_embeddings, device=torch.cuda.current_device(),
+                                 torch.empty(self.num_embeddings + 1, device=torch.cuda.current_device(),
                                              dtype=torch.long).fill_(0),
                                  persistent=False)
+            self.freq_cnter[-1] = sys.maxsize
 
-    def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None:
+    def _update_freq_cnter(self, cpu_row_idxs_original: torch.Tensor) -> None:
         """_update_freq_cnter
 
         Update the frequency value w.r.t. the cpu_row_ids in self.freq_cnter.
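
The comment block added above explains why freq_cnter gets one extra slot pinned to the maximum value. A minimal standalone sketch of that sentinel trick, with toy tensors and names rather than the class's real buffers, looks like this:

    import sys
    import torch

    # toy frequency table: 4 real rows plus one sentinel slot that always holds the max value
    freq_cnter = torch.zeros(4 + 1, dtype=torch.long)
    freq_cnter[-1] = sys.maxsize

    # -1 marks a disabled/masked cache slot; indexing with -1 hits the sentinel
    cached_idx_map = torch.tensor([2, -1, 0, 3])
    freqs_of_cached = freq_cnter[cached_idx_map]      # masked slot sees sys.maxsize
    evict_order = torch.argsort(freqs_of_cached)      # ascending: masked slot lands last
    print(evict_order)                                # the -1 slot is never an early victim
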
@@ -64,7 +74,8 @@ class CachedParamMgr(torch.nn.Module):
             cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
         """
         if self._evict_strategy == EvictionStrategy.LFU:
-            self.freq_cnter[cpu_row_idxs] += 1
+            add_num = torch.bincount(cpu_row_idxs_original)
+            self.freq_cnter[:add_num.shape[0]] += add_num
 
     def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
         """_find_evict_gpu_idxs
@@ -165,10 +176,13 @@ class CachedParamMgr(torch.nn.Module):
             warmup_ratio (float): the amount of chunks preloaded in cuda cache
         """
         if ids_freq_mapping is not None:
+            ids_freq_mapping = torch.tensor(ids_freq_mapping)
             tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
             sorted_idx = torch.argsort(tmp_idx)
             self.idx_map.data.copy_(sorted_idx)
 
+        # initialize freq_cnter if using LFU
+        if self._evict_strategy == EvictionStrategy.LFU:
+            self.freq_cnter[:-1], _ = torch.sort(ids_freq_mapping)
+
         # TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
         # As cuda_cached_weight is very big. You may not have that much available memory!
         # Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
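
The new branch seeds freq_cnter from the per-id dataset counts passed in as ids_freq_mapping instead of starting every row at zero, which is what "initialized with dataset stats" means in the commit title. A rough sketch with toy numbers, mirroring the added lines (the pairing between counter order and idx_map is kept out of the sketch, and the sentinel slot is left at zero for brevity):

    import torch

    ids_freq_mapping = torch.tensor([5, 40, 12])    # how often each id appears in the dataset

    # id -> internal row rank (0 = most frequent), as in the existing idx_map construction
    tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
    idx_map = torch.argsort(tmp_idx)
    print(idx_map)                                  # tensor([2, 0, 1])

    # seed the LFU counter with the (sorted) dataset counts rather than zeros,
    # as the added line above does
    freq_cnter = torch.zeros(len(ids_freq_mapping) + 1, dtype=torch.long)
    freq_cnter[:-1], _ = torch.sort(ids_freq_mapping)
    print(freq_cnter)                               # tensor([ 5, 12, 40,  0])
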
@@ -249,8 +263,9 @@ class CachedParamMgr(torch.nn.Module):
             torch.Tensor: indices on the cuda_cached_weight.
         """
         with record_function("(zhg) get unique indices"):
-            cpu_row_idxs = torch.unique(self.idx_map.index_select(0, ids))
+            cpu_row_idxs_original = self.idx_map.index_select(0, ids)
+            cpu_row_idxs = torch.unique(cpu_row_idxs_original)
 
         assert len(cpu_row_idxs) <= self.cuda_row_num, \
             f"the input indices pull {len(cpu_row_idxs)} chunks, " \
             f"which is larger than the presented {self.cuda_row_num}, " \
@@ -272,10 +287,9 @@ class CachedParamMgr(torch.nn.Module):
         # new ids chunk_offset + offset_in_chunk
         with record_function("(zhg) embed idx -> cache chunk id"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)
-
         # update for LFU.
-        self._update_freq_cnter(cpu_row_idxs)
+        self._update_freq_cnter(cpu_row_idxs_original)
 
         return gpu_row_idxs
 
     def _reset_comm_stats(self):
@@ -298,26 +312,23 @@ class CachedParamMgr(torch.nn.Module):
         if evict_num > 0:
             with Timer() as timer:
                 mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
-                invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                 if self._evict_strategy == EvictionStrategy.DATASET:
                     # mask method.
                     # set cached_idx_map[invalid_idxs] to -2.
                     # so those idxs will be sorted to end, therefore not being chosen as victim
+                    invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
                     backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
                     self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
                     self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
                 elif self._evict_strategy == EvictionStrategy.LFU:
-                    # another mask method.
-                    # set freq_cnter[invalid_idxs] to max
-                    # so those idxs will be sorted to end, therefore not being chosen as victim
-                    backup_cnter = self.freq_cnter[invalid_idxs].clone()
-                    self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1)    # or can we use a confident max value?
+                    invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
+                    backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
+                    self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
                     evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
-                    self.freq_cnter.index_copy_(0, invalid_idxs, backup_cnter)
+                    self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
 
                 evict_info = self.cached_idx_map[evict_gpu_row_idxs]
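
Putting the pieces of this hunk together: slots whose source row sits in the evict_backlist are temporarily re-mapped to -1, so their frequency lookup hits the sentinel, they sort last, and the real mapping is restored afterwards. A self-contained toy version (illustrative tensors, with a plain argsort standing in for _find_evict_gpu_idxs):

    import sys
    import torch

    freq_cnter = torch.tensor([3, 1, 7, 2, sys.maxsize])    # 4 rows + sentinel
    cached_idx_map = torch.tensor([0, 1, 2, 3])              # gpu slot -> cpu row
    evict_backlist = torch.tensor([1])                       # row 1 is in flight, must not be evicted

    mask_cpu_row_idx = torch.isin(cached_idx_map, evict_backlist)
    invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
    backup_idxs = cached_idx_map[mask_cpu_row_idx].clone()
    cached_idx_map.index_fill_(0, invalid_idxs, -1)           # masked slot now points at the sentinel

    freqs = freq_cnter[cached_idx_map]                        # [3, maxsize, 7, 2]
    evict_gpu_row_idxs = torch.argsort(freqs)[:2]             # two least-frequent slots
    cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)  # restore the real mapping
    print(evict_gpu_row_idxs)                                 # tensor([3, 0])
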


@@ -6,7 +6,7 @@ from .freq_aware_embedding import FreqAwareEmbeddingBag
 from colossalai.nn._ops._utils import dual_all_to_all
 from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
-from .cache_mgr import CachedParamMgr
+from .cache_mgr import CachedParamMgr, EvictionStrategy
 
 
 def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
     if world_size == 1:
@@ -48,6 +48,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
                  warmup_ratio=0.7,
                  buffer_size=50_000,
                  pin_weight=False,
+                 evict_strategy: EvictionStrategy = EvictionStrategy.DATASET
                  ):
         self.rank = torch.distributed.get_rank()
         self.world_size = torch.distributed.get_world_size()
@@ -59,7 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
         super(ParallelFreqAwareEmbeddingBag,
               self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
                              sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
-                             warmup_ratio, buffer_size, pin_weight)
+                             warmup_ratio, buffer_size, pin_weight, evict_strategy)
 
     def _weight_alloc(self, dtype, device):
         weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
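
With evict_strategy now threaded through the constructor chain, a caller can pick the policy when building the cache manager. A hedged usage sketch: the import path is a guess based on this repo's layout and may differ by version, and the parameter names follow only the lines visible in the first file's hunks (the full signature may have more arguments):

    import torch
    # import path is an assumption; the diff only shows the relative `.cache_mgr` module
    from colossalai.nn.parallel.layers.cache_embedding.cache_mgr import CachedParamMgr, EvictionStrategy

    weight = torch.randn(1000, 16)          # full embedding table kept on CPU
    mgr = CachedParamMgr(weight,
                         cuda_row_num=64,   # rows resident in the CUDA cache
                         buffer_size=50_000,
                         pin_weight=False,
                         evict_strategy=EvictionStrategy.LFU)   # default stays DATASET

The ParallelFreqAwareEmbeddingBag change above simply forwards the same keyword down the chain to this constructor.
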


@@ -159,6 +159,9 @@ def test_lfu_strategy():
     offsets = torch.tensor([0],device="cuda:0")
 
     # prepare frequency learning info:
+    Bag.forward(torch.tensor([2],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([1,2],device="cuda:0"),offsets)
+    Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
     Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
     Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
     Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
@@ -182,7 +185,7 @@ def test_lfu_strategy():
     assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
         "LFU strategy behavior failed"
 
 
 def gather_tensor(tensor, rank, world_size):
     gather_list = []
     if rank == 0:
@@ -273,6 +276,6 @@ def test_parallel_freq_aware_embed(world_size):
 
 if __name__ == '__main__':
-    test_freq_aware_embed(True)
+    # test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)
-    # test_lfu_strategy()
+    test_lfu_strategy()