[FAW] init an LFU implementation for FAW (#1488)

Jiarui Fang 2022-08-24 17:37:22 +08:00 committed by GitHub
parent 32efe8e740
commit cde7b8a5b8
5 changed files with 112 additions and 39 deletions


@@ -3,10 +3,10 @@ from .linear import ColoLinear
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer
+from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy

 __all__ = [
     'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
     'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer'
+    'LimitBuffIndexCopyer', 'EvictionStrategy'
 ]


@@ -1,6 +1,9 @@
-from .cache_mgr import CachedParamMgr
+from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag

-__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag']
+__all__ = [
+    'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
+    'EvictionStrategy'
+]


@@ -4,6 +4,12 @@ from torch.profiler import record_function
 from typing import List, Optional
 from contexttimer import Timer
 from .copyer import LimitBuffIndexCopyer
+from enum import Enum
+
+
+class EvictionStrategy(Enum):
+    LFU = 1
+    DATASET = 2


 class CachedParamMgr(torch.nn.Module):
@@ -18,7 +24,8 @@ class CachedParamMgr(torch.nn.Module):
                  weight: torch.Tensor,
                  cuda_row_num: int = 0,
                  buffer_size: int = 50_000,
-                 pin_weight=False) -> None:
+                 pin_weight=False,
+                 evict_strategy=EvictionStrategy.DATASET) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
         self.num_embeddings, self.embedding_dim = weight.shape
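For orientation, a minimal sketch of constructing the manager with the new keyword (the argument values are illustrative only, and a CUDA device is required since the cache, and for LFU the frequency counter, live on the GPU):

import torch
from colossalai.nn.parallel.layers import CachedParamMgr, EvictionStrategy

weight = torch.randn(1000, 32)    # cpu weight of a 1000-row embedding table
mgr = CachedParamMgr(weight, cuda_row_num=128, evict_strategy=EvictionStrategy.LFU)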
@@ -38,6 +45,51 @@ class CachedParamMgr(torch.nn.Module):
         self.input_id_percent_in_load_chunk = []
         self._reset_comm_stats()

+        self._evict_strategy = evict_strategy
+        if self._evict_strategy == EvictionStrategy.LFU:
+            # cpu_row_idx -> frequency: the access counts of the cpu rows.
+            # Evict the row with the minimal freq value from the cuda cache.
+            self.register_buffer("freq_cnter",
+                                 torch.empty(self.num_embeddings, device=torch.cuda.current_device(),
+                                             dtype=torch.long).fill_(0),
+                                 persistent=False)
+
+    def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None:
+        """_update_freq_cnter
+
+        Update the frequency values of the rows in cpu_row_idxs in self.freq_cnter.
+
+        Args:
+            cpu_row_idxs (torch.Tensor): indices of the accessed cpu weight rows.
+        """
+        if self._evict_strategy == EvictionStrategy.LFU:
+            self.freq_cnter[cpu_row_idxs] += 1
+
+    def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
+        """_find_evict_gpu_idxs
+
+        Find the gpu row idxs to be evicted, according to the eviction strategy.
+
+        Args:
+            evict_num (int): how many rows have to be evicted.
+
+        Returns:
+            torch.Tensor: a 1D tensor containing the gpu row idxs to evict.
+        """
+        if self._evict_strategy == EvictionStrategy.LFU:
+            # find the evict_num entries of cached_idx_map with the minimal freq.
+            # argsort positions index into cached_idx_map, i.e. they are gpu row idxs.
+            evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
+            return evict_gpu_row_idxs
+        elif self._evict_strategy == EvictionStrategy.DATASET:
+            # cached_idx_map itself implies the priority of eviction.
+            # The value of self.cached_idx_map is a cpu_row_idx: the larger it is,
+            # the less frequently the row appears in the dataset,
+            # and the higher its eviction priority is.
+            return torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
+        else:
+            raise TypeError(f"unsupported eviction strategy: {self._evict_strategy}")
+
     def _init_weight(self, weight):
         if self.cuda_row_num > 0:
             # Enable cache with introducing auxiliary data structures
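To make the LFU bookkeeping above concrete, here is a small self-contained sketch (plain tensors with hypothetical sizes and ids, not the class itself) of how per-row access counts translate into eviction picks:

import torch

num_embeddings = 8
freq_cnter = torch.zeros(num_embeddings, dtype=torch.long)
cached_idx_map = torch.tensor([3, 0, 5, 7])    # gpu slot -> cpu row idx

for batch in [torch.tensor([3, 5, 0]), torch.tensor([3, 5]), torch.tensor([3])]:
    freq_cnter[batch] += 1    # what _update_freq_cnter does per lookup

# counts: row0=1, row3=3, row5=2, row7=0, so the cached rows' counts are
# [3, 1, 2, 0] and ascending argsort ranks slots 3 and 1 first.
evict_gpu_slots = torch.argsort(freq_cnter[cached_idx_map])[:2]
print(evict_gpu_slots)    # tensor([3, 1]) -> evict rows 7 and 0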
@@ -220,6 +272,10 @@ class CachedParamMgr(torch.nn.Module):
         # new ids chunk_offset + offset_in_chunk
         with record_function("(zhg) embed idx -> cache chunk id"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)

+        # update for LFU.
+        self._update_freq_cnter(cpu_row_idxs)
+
         return gpu_row_idxs

     def _reset_comm_stats(self):
@@ -234,6 +290,7 @@ class CachedParamMgr(torch.nn.Module):
     @torch.no_grad()
     def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
         """prepare rows in cpu_row_idxs on CUDA memory
+
         Args:
             cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
         """
@@ -245,7 +302,9 @@ class CachedParamMgr(torch.nn.Module):
             invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
             self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
-            evict_gpu_row_idxs = torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
+
+            evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+
             self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)

             evict_info = self.cached_idx_map[evict_gpu_row_idxs]
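The DATASET branch leans on the reorder() convention that rows are laid out by decreasing dataset frequency, so a larger cpu row idx means a rarer id. A hypothetical illustration:

import torch

cached_idx_map = torch.tensor([2, 9, 0, 6])    # gpu slot -> cpu row idx
evict_gpu_slots = torch.argsort(cached_idx_map, descending=True)[:2]
print(evict_gpu_slots)    # tensor([1, 3]) -> evict rows 9 and 6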
@@ -291,8 +350,16 @@ class CachedParamMgr(torch.nn.Module):
         self._cpu_to_cuda_numel += weight_size
         # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")

+    def _find_free_cuda_row(self) -> int:
+        if self._cuda_available_row_num == 0:
+            return -1
+        candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
+        return candidates[0].item()
+
     def _evict(self) -> int:
         """
+        deprecated
         evict one chunk from cuda to cpu.

         Returns:
             (int) : the slot id to be evicted.
@@ -329,15 +396,11 @@ class CachedParamMgr(torch.nn.Module):
         # self.num_write_back_history[-1] += 1
         return max_cpu_row_idx

-    def _find_free_cuda_row(self) -> int:
-        if self._cuda_available_row_num == 0:
-            return -1
-        candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
-        return candidates[0].item()
-
     @torch.no_grad()
     def _admit(self, row_id: int):
         """
+        deprecated
         move in row_id to CUDA

         Args:


@@ -3,35 +3,35 @@ import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple

 from .base_embedding import BaseEmbeddingBag
-from .cache_mgr import CachedParamMgr
+from .cache_mgr import CachedParamMgr, EvictionStrategy
 from torch.nn.parameter import Parameter


 class FreqAwareEmbeddingBag(BaseEmbeddingBag):

-    def __init__(
-            self,
-            num_embeddings,
-            embedding_dim,
-            padding_idx=None,
-            max_norm=None,
-            norm_type=2.,
-            scale_grad_by_freq=False,
-            sparse=False,
-            _weight=None,
-            mode='mean',
-            include_last_offset=False,
-            dtype=None,
-            device=None,
-            cuda_row_num=0,
-            ids_freq_mapping=None,
-            warmup_ratio=0.7,
-            buffer_size=50_000,
-            pin_weight=False,
-    ):
+    def __init__(self,
+                 num_embeddings,
+                 embedding_dim,
+                 padding_idx=None,
+                 max_norm=None,
+                 norm_type=2.,
+                 scale_grad_by_freq=False,
+                 sparse=False,
+                 _weight=None,
+                 mode='mean',
+                 include_last_offset=False,
+                 dtype=None,
+                 device=None,
+                 cuda_row_num=0,
+                 ids_freq_mapping=None,
+                 warmup_ratio=0.7,
+                 buffer_size=50_000,
+                 pin_weight=False,
+                 evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
         super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
                                                     scale_grad_by_freq, sparse, mode, include_last_offset)
+        self.evict_strategy = evict_strategy
         if _weight is None:
             _weight = self._weight_alloc(dtype, device)

@@ -63,7 +63,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
             ids_freq_mapping (List[int]): a list whose idx is the id number and whose value is the freq
             warmup_ratio (float): the ratio of rows preloaded into the cuda cache
         """
-        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size, pin_weight)
+        self.cache_weight_mgr = CachedParamMgr(weight,
+                                               cuda_row_num,
+                                               buffer_size,
+                                               pin_weight,
+                                               evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)

     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
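Putting the pieces together, a minimal usage sketch of the new keyword on the public class (argument values are illustrative; a CUDA device is required):

import torch
from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag, EvictionStrategy

# cache 128 of the 1000 rows on the GPU and evict by LFU when it fills up
bag = FreqAwareEmbeddingBag(1000,
                            32,
                            mode='mean',
                            include_last_offset=True,
                            cuda_row_num=128,
                            evict_strategy=EvictionStrategy.LFU).to(torch.device('cuda', 0))

indices = torch.randint(0, 1000, (64,), device='cuda')
offsets = torch.arange(0, 65, 8, device='cuda')    # 8 bags, last offset included
out = bag(indices, offsets)    # shape (8, 32)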


@@ -12,7 +12,7 @@ from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy

 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8

@@ -41,6 +41,7 @@ def synthesize_1d_sparse_feature(
     return indices, offsets


+@pytest.mark.skip
 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda

@@ -98,14 +99,17 @@ def test_reorder_with_freq():
         f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"


-def test_freq_aware_embed():
+@pytest.mark.parametrize('use_LFU', [True, False])
+def test_freq_aware_embed(use_LFU: bool):
     device = torch.device('cuda', 0)
+    evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
     model = FreqAwareEmbeddingBag(NUM_EMBED,
                                   EMBED_DIM,
                                   mode='mean',
                                   include_last_offset=True,
                                   cuda_row_num=BATCH_SIZE * 2,
-                                  ids_freq_mapping=None).to(device)
+                                  ids_freq_mapping=None,
+                                  evict_strategy=evict_strategy).to(device)
     assert model.weight.shape[0] == NUM_EMBED

     ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),

@@ -231,6 +235,5 @@ def test_parallel_freq_aware_embed(world_size):

 if __name__ == '__main__':
-    test_cachemgr()
-    # test_freq_aware_embed()
+    test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)