Browse Source

[FAW] FAW embedding use LRU as eviction strategy intialized with dataset stats (#1494)

pull/1500/head
CsRic 2 years ago committed by GitHub
parent
commit
0ed2f46131
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 51
      colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py
  2. 5
      colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
  3. 9
      tests/test_layers/test_cache_embedding.py

51
colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py

@ -5,7 +5,7 @@ from typing import List, Optional
from contexttimer import Timer
from .copyer import LimitBuffIndexCopyer
from enum import Enum
import sys
class EvictionStrategy(Enum):
LFU = 1
@ -25,14 +25,14 @@ class CachedParamMgr(torch.nn.Module):
cuda_row_num: int = 0,
buffer_size: int = 50_000,
pin_weight=False,
evict_strategy=EvictionStrategy.DATASET) -> None:
evict_strategy=EvictionStrategy.DATASET,) -> None:
super(CachedParamMgr, self).__init__()
self.buffer_size = buffer_size
self.num_embeddings, self.embedding_dim = weight.shape
self.cuda_row_num = cuda_row_num
self._cuda_available_row_num = self.cuda_row_num
self.pin_weight = pin_weight
self.elem_size_in_byte = weight.element_size()
# weight configure
@ -50,12 +50,22 @@ class CachedParamMgr(torch.nn.Module):
if self._evict_strategy == EvictionStrategy.LFU:
# cpu_row_idx -> frequency, freq of the cpu rows.
# evict the minimal freq value row in cuda cache.
'''
during cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we re-map that element to -1 temporary.
also, disabled cached_idx_map element maps to -1 by default.
freq_cnter[-1], the last element, should ALWAYS be MAX VALUE so those masked or disabled idxs will be argsorted to end,
not being chosen to evict.
ZH: freq_cnter的最后一位设为了最大值, 不该被选为换出的cache idx都是-1, 指向这个最大值, 所以排序时在队尾, 不会被选中换出
'''
self.register_buffer("freq_cnter",
torch.empty(self.num_embeddings, device=torch.cuda.current_device(),
torch.empty(self.num_embeddings + 1, device=torch.cuda.current_device(),
dtype=torch.long).fill_(0),
persistent=False)
self.freq_cnter[-1] = sys.maxsize
def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None:
def _update_freq_cnter(self, cpu_row_idxs_original: torch.Tensor) -> None:
"""_update_freq_cnter
Update the frequency valude w.r.t. the cpu_row_ids in self.freq_cnter.
@ -64,7 +74,8 @@ class CachedParamMgr(torch.nn.Module):
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
"""
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter[cpu_row_idxs] += 1
add_num = torch.bincount(cpu_row_idxs_original)
self.freq_cnter[:add_num.shape[0]] += add_num
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
"""_find_evict_gpu_idxs
@ -165,10 +176,13 @@ class CachedParamMgr(torch.nn.Module):
warmup_ratio (float): the amount of chunks preloaded in cuda cache
"""
if ids_freq_mapping is not None:
ids_freq_mapping = torch.tensor(ids_freq_mapping)
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
sorted_idx = torch.argsort(tmp_idx)
self.idx_map.data.copy_(sorted_idx)
#initialize freq_cnter if use LFU
if self._evict_strategy == EvictionStrategy.LFU:
self.freq_cnter[:-1],_ = torch.sort(ids_freq_mapping)
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
# As cuda_cached_weight is very big. You may not have that much available memory!
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
@ -249,8 +263,9 @@ class CachedParamMgr(torch.nn.Module):
torch.Tensor: indices on the cuda_cached_weight.
"""
with record_function("(zhg) get unique indices"):
cpu_row_idxs = torch.unique(self.idx_map.index_select(0, ids))
cpu_row_idxs_original = self.idx_map.index_select(0, ids)
cpu_row_idxs = torch.unique(cpu_row_idxs_original)
assert len(cpu_row_idxs) <= self.cuda_row_num, \
f"the input indices pull {len(cpu_row_idxs)} chunks, " \
f"which is larger than the presented {self.cuda_row_num}, " \
@ -272,10 +287,9 @@ class CachedParamMgr(torch.nn.Module):
# new ids chunk_offset + offset_in_chunk
with record_function("(zhg) embed idx -> cache chunk id"):
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
# update for LFU.
self._update_freq_cnter(cpu_row_idxs)
self._update_freq_cnter(cpu_row_idxs_original)
return gpu_row_idxs
def _reset_comm_stats(self):
@ -298,26 +312,23 @@ class CachedParamMgr(torch.nn.Module):
if evict_num > 0:
with Timer() as timer:
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
if self._evict_strategy == EvictionStrategy.DATASET:
# mask method.
# set cached_idx_map[invalid_idxs] to -2.
# so those idxs will be sorted to end, therefore not being chosen as victim
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
elif self._evict_strategy == EvictionStrategy.LFU:
# another mask method.
# set freq_cnter[invalid_idxs] to max
# so those idxs will be sorted to end, therefore not being chosen as victim
backup_cnter = self.freq_cnter[invalid_idxs].clone()
self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1) # or can we use a confident max value?
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
self.freq_cnter.index_copy_(0,invalid_idxs,backup_cnter)
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
evict_info = self.cached_idx_map[evict_gpu_row_idxs]

5
colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py

@ -6,7 +6,7 @@ from .freq_aware_embedding import FreqAwareEmbeddingBag
from colossalai.nn._ops._utils import dual_all_to_all
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
from .cache_mgr import CachedParamMgr, EvictionStrategy
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
if world_size == 1:
@ -48,6 +48,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
warmup_ratio=0.7,
buffer_size=50_000,
pin_weight=False,
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET
):
self.rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()
@ -59,7 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
super(ParallelFreqAwareEmbeddingBag,
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
warmup_ratio, buffer_size, pin_weight)
warmup_ratio, buffer_size, pin_weight,evict_strategy)
def _weight_alloc(self, dtype, device):
weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)

9
tests/test_layers/test_cache_embedding.py

@ -159,6 +159,9 @@ def test_lfu_strategy():
offsets = torch.tensor([0],device="cuda:0")
# prepare frequency learning info:
Bag.forward(torch.tensor([2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
@ -182,7 +185,7 @@ def test_lfu_strategy():
assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
"LFU strategy behavior failed"
def gather_tensor(tensor, rank, world_size):
gather_list = []
if rank == 0:
@ -273,6 +276,6 @@ def test_parallel_freq_aware_embed(world_size):
if __name__ == '__main__':
test_freq_aware_embed(True)
# test_freq_aware_embed(True)
# test_parallel_freq_aware_embed(2)
# test_lfu_strategy()
test_lfu_strategy()
Loading…
Cancel
Save