mirror of https://github.com/hpcaitech/ColossalAI
[FAW] FAW embedding use LRU as eviction strategy intialized with dataset stats (#1494)
parent
8b7d6bd5be
commit
0ed2f46131
|
@ -5,7 +5,7 @@ from typing import List, Optional
|
||||||
from contexttimer import Timer
|
from contexttimer import Timer
|
||||||
from .copyer import LimitBuffIndexCopyer
|
from .copyer import LimitBuffIndexCopyer
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
import sys
|
||||||
|
|
||||||
class EvictionStrategy(Enum):
|
class EvictionStrategy(Enum):
|
||||||
LFU = 1
|
LFU = 1
|
||||||
|
@ -25,14 +25,14 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
cuda_row_num: int = 0,
|
cuda_row_num: int = 0,
|
||||||
buffer_size: int = 50_000,
|
buffer_size: int = 50_000,
|
||||||
pin_weight=False,
|
pin_weight=False,
|
||||||
evict_strategy=EvictionStrategy.DATASET) -> None:
|
evict_strategy=EvictionStrategy.DATASET,) -> None:
|
||||||
super(CachedParamMgr, self).__init__()
|
super(CachedParamMgr, self).__init__()
|
||||||
self.buffer_size = buffer_size
|
self.buffer_size = buffer_size
|
||||||
self.num_embeddings, self.embedding_dim = weight.shape
|
self.num_embeddings, self.embedding_dim = weight.shape
|
||||||
self.cuda_row_num = cuda_row_num
|
self.cuda_row_num = cuda_row_num
|
||||||
self._cuda_available_row_num = self.cuda_row_num
|
self._cuda_available_row_num = self.cuda_row_num
|
||||||
self.pin_weight = pin_weight
|
self.pin_weight = pin_weight
|
||||||
|
|
||||||
self.elem_size_in_byte = weight.element_size()
|
self.elem_size_in_byte = weight.element_size()
|
||||||
|
|
||||||
# weight configure
|
# weight configure
|
||||||
|
@ -50,12 +50,22 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
if self._evict_strategy == EvictionStrategy.LFU:
|
if self._evict_strategy == EvictionStrategy.LFU:
|
||||||
# cpu_row_idx -> frequency, freq of the cpu rows.
|
# cpu_row_idx -> frequency, freq of the cpu rows.
|
||||||
# evict the minimal freq value row in cuda cache.
|
# evict the minimal freq value row in cuda cache.
|
||||||
|
|
||||||
|
'''
|
||||||
|
during cache eviction, if a cached_idx_map element maps to a masked cpu_idx, we re-map that element to -1 temporary.
|
||||||
|
also, disabled cached_idx_map element maps to -1 by default.
|
||||||
|
freq_cnter[-1], the last element, should ALWAYS be MAX VALUE so those masked or disabled idxs will be argsorted to end,
|
||||||
|
not being chosen to evict.
|
||||||
|
|
||||||
|
ZH: freq_cnter的最后一位设为了最大值, 不该被选为换出的cache idx都是-1, 指向这个最大值, 所以排序时在队尾, 不会被选中换出
|
||||||
|
'''
|
||||||
self.register_buffer("freq_cnter",
|
self.register_buffer("freq_cnter",
|
||||||
torch.empty(self.num_embeddings, device=torch.cuda.current_device(),
|
torch.empty(self.num_embeddings + 1, device=torch.cuda.current_device(),
|
||||||
dtype=torch.long).fill_(0),
|
dtype=torch.long).fill_(0),
|
||||||
persistent=False)
|
persistent=False)
|
||||||
|
self.freq_cnter[-1] = sys.maxsize
|
||||||
|
|
||||||
def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None:
|
def _update_freq_cnter(self, cpu_row_idxs_original: torch.Tensor) -> None:
|
||||||
"""_update_freq_cnter
|
"""_update_freq_cnter
|
||||||
|
|
||||||
Update the frequency valude w.r.t. the cpu_row_ids in self.freq_cnter.
|
Update the frequency valude w.r.t. the cpu_row_ids in self.freq_cnter.
|
||||||
|
@ -64,7 +74,8 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
|
cpu_row_idxs (torch.Tensor): a list of indices of cpu weight.
|
||||||
"""
|
"""
|
||||||
if self._evict_strategy == EvictionStrategy.LFU:
|
if self._evict_strategy == EvictionStrategy.LFU:
|
||||||
self.freq_cnter[cpu_row_idxs] += 1
|
add_num = torch.bincount(cpu_row_idxs_original)
|
||||||
|
self.freq_cnter[:add_num.shape[0]] += add_num
|
||||||
|
|
||||||
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
|
def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
|
||||||
"""_find_evict_gpu_idxs
|
"""_find_evict_gpu_idxs
|
||||||
|
@ -165,10 +176,13 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
warmup_ratio (float): the amount of chunks preloaded in cuda cache
|
warmup_ratio (float): the amount of chunks preloaded in cuda cache
|
||||||
"""
|
"""
|
||||||
if ids_freq_mapping is not None:
|
if ids_freq_mapping is not None:
|
||||||
|
ids_freq_mapping = torch.tensor(ids_freq_mapping)
|
||||||
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
|
tmp_idx = torch.argsort(ids_freq_mapping, descending=True)
|
||||||
sorted_idx = torch.argsort(tmp_idx)
|
sorted_idx = torch.argsort(tmp_idx)
|
||||||
self.idx_map.data.copy_(sorted_idx)
|
self.idx_map.data.copy_(sorted_idx)
|
||||||
|
#initialize freq_cnter if use LFU
|
||||||
|
if self._evict_strategy == EvictionStrategy.LFU:
|
||||||
|
self.freq_cnter[:-1],_ = torch.sort(ids_freq_mapping)
|
||||||
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
|
# TODO() The following code will allocate extra CUDA memory. preload_row_num * chunks.
|
||||||
# As cuda_cached_weight is very big. You may not have that much available memory!
|
# As cuda_cached_weight is very big. You may not have that much available memory!
|
||||||
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
|
# Warmup the cuda cache by moving high freq chunks (lowest chunk id) to cuda
|
||||||
|
@ -249,8 +263,9 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
torch.Tensor: indices on the cuda_cached_weight.
|
torch.Tensor: indices on the cuda_cached_weight.
|
||||||
"""
|
"""
|
||||||
with record_function("(zhg) get unique indices"):
|
with record_function("(zhg) get unique indices"):
|
||||||
cpu_row_idxs = torch.unique(self.idx_map.index_select(0, ids))
|
cpu_row_idxs_original = self.idx_map.index_select(0, ids)
|
||||||
|
cpu_row_idxs = torch.unique(cpu_row_idxs_original)
|
||||||
|
|
||||||
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
assert len(cpu_row_idxs) <= self.cuda_row_num, \
|
||||||
f"the input indices pull {len(cpu_row_idxs)} chunks, " \
|
f"the input indices pull {len(cpu_row_idxs)} chunks, " \
|
||||||
f"which is larger than the presented {self.cuda_row_num}, " \
|
f"which is larger than the presented {self.cuda_row_num}, " \
|
||||||
|
@ -272,10 +287,9 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
# new ids chunk_offset + offset_in_chunk
|
# new ids chunk_offset + offset_in_chunk
|
||||||
with record_function("(zhg) embed idx -> cache chunk id"):
|
with record_function("(zhg) embed idx -> cache chunk id"):
|
||||||
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
|
gpu_row_idxs = self._id_to_cached_cuda_id(ids)
|
||||||
|
|
||||||
# update for LFU.
|
# update for LFU.
|
||||||
self._update_freq_cnter(cpu_row_idxs)
|
self._update_freq_cnter(cpu_row_idxs_original)
|
||||||
|
|
||||||
return gpu_row_idxs
|
return gpu_row_idxs
|
||||||
|
|
||||||
def _reset_comm_stats(self):
|
def _reset_comm_stats(self):
|
||||||
|
@ -298,26 +312,23 @@ class CachedParamMgr(torch.nn.Module):
|
||||||
if evict_num > 0:
|
if evict_num > 0:
|
||||||
with Timer() as timer:
|
with Timer() as timer:
|
||||||
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
mask_cpu_row_idx = torch.isin(self.cached_idx_map, self.evict_backlist)
|
||||||
|
|
||||||
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
|
||||||
|
|
||||||
if self._evict_strategy == EvictionStrategy.DATASET:
|
if self._evict_strategy == EvictionStrategy.DATASET:
|
||||||
# mask method.
|
# mask method.
|
||||||
# set cached_idx_map[invalid_idxs] to -2.
|
# set cached_idx_map[invalid_idxs] to -2.
|
||||||
# so those idxs will be sorted to end, therefore not being chosen as victim
|
# so those idxs will be sorted to end, therefore not being chosen as victim
|
||||||
|
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
||||||
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
||||||
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
|
||||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||||
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||||
|
|
||||||
elif self._evict_strategy == EvictionStrategy.LFU:
|
elif self._evict_strategy == EvictionStrategy.LFU:
|
||||||
# another mask method.
|
invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
|
||||||
# set freq_cnter[invalid_idxs] to max
|
backup_idxs = self.cached_idx_map[mask_cpu_row_idx].clone()
|
||||||
# so those idxs will be sorted to end, therefore not being chosen as victim
|
self.cached_idx_map.index_fill_(0, invalid_idxs, -1)
|
||||||
backup_cnter = self.freq_cnter[invalid_idxs].clone()
|
|
||||||
self.freq_cnter.index_fill_(0, invalid_idxs, torch.max(self.freq_cnter) + 1) # or can we use a confident max value?
|
|
||||||
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
|
||||||
self.freq_cnter.index_copy_(0,invalid_idxs,backup_cnter)
|
self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
|
||||||
|
|
||||||
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
evict_info = self.cached_idx_map[evict_gpu_row_idxs]
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from .freq_aware_embedding import FreqAwareEmbeddingBag
|
||||||
from colossalai.nn._ops._utils import dual_all_to_all
|
from colossalai.nn._ops._utils import dual_all_to_all
|
||||||
|
|
||||||
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
|
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
|
||||||
|
from .cache_mgr import CachedParamMgr, EvictionStrategy
|
||||||
|
|
||||||
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
|
@ -48,6 +48,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
||||||
warmup_ratio=0.7,
|
warmup_ratio=0.7,
|
||||||
buffer_size=50_000,
|
buffer_size=50_000,
|
||||||
pin_weight=False,
|
pin_weight=False,
|
||||||
|
evict_strategy: EvictionStrategy = EvictionStrategy.DATASET
|
||||||
):
|
):
|
||||||
self.rank = torch.distributed.get_rank()
|
self.rank = torch.distributed.get_rank()
|
||||||
self.world_size = torch.distributed.get_world_size()
|
self.world_size = torch.distributed.get_world_size()
|
||||||
|
@ -59,7 +60,7 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
|
||||||
super(ParallelFreqAwareEmbeddingBag,
|
super(ParallelFreqAwareEmbeddingBag,
|
||||||
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
|
self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
|
||||||
sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
|
sparse, _weight, mode, include_last_offset, dtype, device, cuda_row_num, ids_freq_mapping,
|
||||||
warmup_ratio, buffer_size, pin_weight)
|
warmup_ratio, buffer_size, pin_weight,evict_strategy)
|
||||||
|
|
||||||
def _weight_alloc(self, dtype, device):
|
def _weight_alloc(self, dtype, device):
|
||||||
weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
|
weight = torch.empty(self.num_embeddings, self.embedding_dim_per_partition, device=device, dtype=dtype)
|
||||||
|
|
|
@ -159,6 +159,9 @@ def test_lfu_strategy():
|
||||||
offsets = torch.tensor([0],device="cuda:0")
|
offsets = torch.tensor([0],device="cuda:0")
|
||||||
|
|
||||||
# prepare frequency learning info:
|
# prepare frequency learning info:
|
||||||
|
Bag.forward(torch.tensor([2],device="cuda:0"),offsets)
|
||||||
|
Bag.forward(torch.tensor([1,2],device="cuda:0"),offsets)
|
||||||
|
Bag.forward(torch.tensor([0,2],device="cuda:0"),offsets)
|
||||||
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
||||||
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
||||||
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
Bag.forward(torch.tensor([0,1,2],device="cuda:0"),offsets)
|
||||||
|
@ -182,7 +185,7 @@ def test_lfu_strategy():
|
||||||
|
|
||||||
assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
|
assert torch.allclose(torch.Tensor(Bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
|
||||||
"LFU strategy behavior failed"
|
"LFU strategy behavior failed"
|
||||||
|
|
||||||
def gather_tensor(tensor, rank, world_size):
|
def gather_tensor(tensor, rank, world_size):
|
||||||
gather_list = []
|
gather_list = []
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
|
@ -273,6 +276,6 @@ def test_parallel_freq_aware_embed(world_size):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_freq_aware_embed(True)
|
# test_freq_aware_embed(True)
|
||||||
# test_parallel_freq_aware_embed(2)
|
# test_parallel_freq_aware_embed(2)
|
||||||
# test_lfu_strategy()
|
test_lfu_strategy()
|
Loading…
Reference in New Issue