mirror of https://github.com/hpcaitech/ColossalAI
[FAW] init an LFU implementation for FAW (#1488)
parent 32efe8e740
commit cde7b8a5b8
@@ -3,10 +3,10 @@ from .linear import ColoLinear
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer
+from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy
 
 __all__ = [
     'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
     'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer'
+    'LimitBuffIndexCopyer', 'EvictionStrategy'
 ]
@@ -1,6 +1,9 @@
-from .cache_mgr import CachedParamMgr
+from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
 
-__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag']
+__all__ = [
+    'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
+    'EvictionStrategy'
+]
@@ -4,6 +4,12 @@ from torch.profiler import record_function
 from typing import List, Optional
 from contexttimer import Timer
 from .copyer import LimitBuffIndexCopyer
+from enum import Enum
+
+
+class EvictionStrategy(Enum):
+    LFU = 1
+    DATASET = 2
 
 
 class CachedParamMgr(torch.nn.Module):
@@ -18,7 +24,8 @@ class CachedParamMgr(torch.nn.Module):
                  weight: torch.Tensor,
                  cuda_row_num: int = 0,
                  buffer_size: int = 50_000,
-                 pin_weight=False) -> None:
+                 pin_weight=False,
+                 evict_strategy=EvictionStrategy.DATASET) -> None:
         super(CachedParamMgr, self).__init__()
         self.buffer_size = buffer_size
         self.num_embeddings, self.embedding_dim = weight.shape
@@ -38,6 +45,51 @@ class CachedParamMgr(torch.nn.Module):
         self.input_id_percent_in_load_chunk = []
         self._reset_comm_stats()
 
+        self._evict_strategy = evict_strategy
+
+        if self._evict_strategy == EvictionStrategy.LFU:
+            # cpu_row_idx -> frequency, freq of the cpu rows.
+            # evict the row with the minimal freq value from the cuda cache.
+            self.register_buffer("freq_cnter",
+                                 torch.empty(self.num_embeddings, device=torch.cuda.current_device(),
+                                             dtype=torch.long).fill_(0),
+                                 persistent=False)
+
+    def _update_freq_cnter(self, cpu_row_idxs: torch.Tensor) -> None:
+        """_update_freq_cnter
+
+        Update the frequency values w.r.t. the cpu_row_idxs in self.freq_cnter.
+
+        Args:
+            cpu_row_idxs (torch.Tensor): a list of indices of the cpu weight.
+        """
+        if self._evict_strategy == EvictionStrategy.LFU:
+            self.freq_cnter[cpu_row_idxs] += 1
+
+    def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor:
+        """_find_evict_gpu_idxs
+
+        Find the gpu idxs to be evicted, according to their freq.
+
+        Args:
+            evict_num (int): how many rows have to be evicted.
+
+        Returns:
+            torch.Tensor: a 1D tensor containing the gpu_row_idxs.
+        """
+        if self._evict_strategy == EvictionStrategy.LFU:
+            # find the evict_num entries with the minimal freq in cached_idx_map
+            evict_gpu_row_idxs = torch.argsort(self.freq_cnter[self.cached_idx_map])[:evict_num]
+            return self.cached_idx_map[evict_gpu_row_idxs]
+        elif self._evict_strategy == EvictionStrategy.DATASET:
+            # cached_idx_map itself implies the priority of eviction.
+            # The value of self.cached_idx_map represents cpu_row_idx.
+            # The larger it is, the less frequently the id appears in the dataset,
+            # and the higher its eviction priority is.
+            return torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
+        else:
+            raise TypeError
+
     def _init_weight(self, weight):
         if self.cuda_row_num > 0:
             # Enable cache with introducing auxiliary data structures
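The LFU bookkeeping added above boils down to two tensors: freq_cnter, bumped on every lookup, and cached_idx_map, which maps a cache slot to the cpu row it holds. A minimal standalone sketch of that selection logic, not part of the commit; the names mirror the diff, the toy sizes are made up:

    # Toy-sized LFU eviction pick, mirroring _update_freq_cnter and the
    # LFU branch of _find_evict_gpu_idxs. CPU tensors for simplicity.
    import torch

    num_embeddings = 10
    freq_cnter = torch.zeros(num_embeddings, dtype=torch.long)
    cached_idx_map = torch.tensor([7, 2, 9, 5])    # cache slot -> cpu row idx

    # _update_freq_cnter: bump the counter of every (unique) row a batch touches.
    for batch in (torch.tensor([7, 2]), torch.tensor([2, 5, 7])):
        freq_cnter[batch] += 1

    # LFU pick: sort cached slots by the frequency of the row they hold and
    # take the least-used ones first.
    evict_num = 2
    slot_order = torch.argsort(freq_cnter[cached_idx_map])
    evict_rows = cached_idx_map[slot_order[:evict_num]]
    print(evict_rows.tolist())    # [9, 5]: requested never / once, evicted first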
@@ -220,6 +272,10 @@ class CachedParamMgr(torch.nn.Module):
         # new ids chunk_offset + offset_in_chunk
         with record_function("(zhg) embed idx -> cache chunk id"):
             gpu_row_idxs = self._id_to_cached_cuda_id(ids)
+
+        # update for LFU.
+        self._update_freq_cnter(cpu_row_idxs)
+
         return gpu_row_idxs
 
     def _reset_comm_stats(self):
@@ -234,6 +290,7 @@ class CachedParamMgr(torch.nn.Module):
     @torch.no_grad()
     def _prepare_rows_on_cuda(self, cpu_row_idxs: torch.Tensor) -> None:
         """prepare rows in cpu_row_idxs on CUDA memory
+
         Args:
             cpu_row_idxs (torch.Tensor): the chunks to be placed on CUDA
         """
@@ -245,7 +302,9 @@ class CachedParamMgr(torch.nn.Module):
                 invalid_idxs = torch.nonzero(mask_cpu_row_idx).squeeze(1)
+
                 self.cached_idx_map.index_fill_(0, invalid_idxs, -2)
-                evict_gpu_row_idxs = torch.argsort(self.cached_idx_map, descending=True)[:evict_num]
+                evict_gpu_row_idxs = self._find_evict_gpu_idxs(evict_num)
+
                 self.cached_idx_map.index_copy_(0, invalid_idxs, backup_idxs)
 
                 evict_info = self.cached_idx_map[evict_gpu_row_idxs]
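One subtlety in the hunk above: rows that the incoming batch itself needs must never be chosen as eviction victims. The code temporarily overwrites their entries in cached_idx_map with the sentinel -2 (index_fill_), runs the eviction search, then restores the saved values (index_copy_). A toy illustration of the trick under the DATASET strategy; illustrative only, and needed_slots is a hypothetical stand-in for the masked indices:

    import torch

    cached_idx_map = torch.tensor([7, 2, 9, 5])       # cache slot -> cpu row idx
    needed_slots = torch.tensor([1])                  # slot 1 (row 2) is needed now
    backup_idxs = cached_idx_map[needed_slots].clone()

    cached_idx_map.index_fill_(0, needed_slots, -2)   # -2 sorts last, never picked
    # DATASET strategy: a larger cpu row idx means a rarer id, evicted first.
    evict_slots = torch.argsort(cached_idx_map, descending=True)[:2]
    cached_idx_map.index_copy_(0, needed_slots, backup_idxs)    # restore

    print(cached_idx_map[evict_slots].tolist())       # [9, 7]; row 2 was protected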
@@ -291,8 +350,16 @@ class CachedParamMgr(torch.nn.Module):
         self._cpu_to_cuda_numel += weight_size
         # print(f"admit embedding weight: {weight_size*self.elem_size_in_byte/1e6:.2f} MB")
 
+    def _find_free_cuda_row(self) -> int:
+        if self._cuda_available_row_num == 0:
+            return -1
+        candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
+        return candidates[0].item()
+
     def _evict(self) -> int:
         """
+        deprecated
+
         evict one chunk from cuda to cpu.
         Returns:
             (int) : the slot id to be evicted.
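The relocated _find_free_cuda_row relies on the convention that an empty cache slot stores -1 in cached_idx_map, so torch.nonzero yields the free slots directly. A short demonstration of that convention, illustrative only:

    import torch

    cached_idx_map = torch.tensor([7, -1, 9, -1])     # slots 1 and 3 are free
    candidates = torch.nonzero(cached_idx_map == -1).squeeze(1)
    free_slot = candidates[0].item() if candidates.numel() > 0 else -1
    print(free_slot)    # 1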
@@ -329,15 +396,11 @@ class CachedParamMgr(torch.nn.Module):
         # self.num_write_back_history[-1] += 1
         return max_cpu_row_idx
 
-    def _find_free_cuda_row(self) -> int:
-        if self._cuda_available_row_num == 0:
-            return -1
-        candidates = torch.nonzero(self.cached_idx_map == -1).squeeze(1)
-        return candidates[0].item()
-
     @torch.no_grad()
     def _admit(self, row_id: int):
         """
+        deprecated
+
         move in row_id to CUDA
 
         Args:
@@ -3,14 +3,13 @@ import torch.nn.functional as F
 from typing import List, Optional, Iterator, Tuple
 
 from .base_embedding import BaseEmbeddingBag
-from .cache_mgr import CachedParamMgr
+from .cache_mgr import CachedParamMgr, EvictionStrategy
 from torch.nn.parameter import Parameter
 
 
 class FreqAwareEmbeddingBag(BaseEmbeddingBag):
 
-    def __init__(
-            self,
+    def __init__(self,
                  num_embeddings,
                  embedding_dim,
                  padding_idx=None,
@@ -28,10 +27,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
                  warmup_ratio=0.7,
                  buffer_size=50_000,
                  pin_weight=False,
-    ):
+                 evict_strategy: EvictionStrategy = EvictionStrategy.DATASET):
         super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
                                                     scale_grad_by_freq, sparse, mode, include_last_offset)
 
+        self.evict_strategy = evict_strategy
         if _weight is None:
             _weight = self._weight_alloc(dtype, device)
@@ -63,7 +63,11 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
             ids_freq_mapping (List[int]): a list, idx is id number, value is freq
             warmup_ratio (float): the amount of rows preloaded in cuda cache
         """
-        self.cache_weight_mgr = CachedParamMgr(weight, cuda_row_num, buffer_size, pin_weight)
+        self.cache_weight_mgr = CachedParamMgr(weight,
+                                               cuda_row_num,
+                                               buffer_size,
+                                               pin_weight,
+                                               evict_strategy=self.evict_strategy)
         self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
 
     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
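Putting the pieces together, a minimal usage sketch of the new keyword, based only on the signatures visible in this diff; it assumes a CUDA device and an installed colossalai, and the sizes are made up (the parametrized test below exercises the same path):

    import torch
    from colossalai.nn.parallel.layers import FreqAwareEmbeddingBag, EvictionStrategy

    device = torch.device('cuda', 0)
    model = FreqAwareEmbeddingBag(1000,                  # num_embeddings
                                  32,                    # embedding_dim
                                  mode='mean',
                                  include_last_offset=True,
                                  cuda_row_num=64,       # rows resident in the CUDA cache
                                  evict_strategy=EvictionStrategy.LFU).to(device)

    indices = torch.randint(0, 1000, (128,), device=device)
    offsets = torch.arange(0, 129, 8, device=device)     # 16 bags + trailing offset
    out = model(indices, offsets)                        # shape (16, 32)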
@@ -12,7 +12,7 @@ from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy
 
 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
@@ -41,6 +41,7 @@ def synthesize_1d_sparse_feature(
     return indices, offsets
 
 
+@pytest.mark.skip
 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda
@@ -98,14 +99,17 @@ def test_reorder_with_freq():
         f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
 
 
-def test_freq_aware_embed():
+@pytest.mark.parametrize('use_LFU', [True, False])
+def test_freq_aware_embed(use_LFU: bool):
     device = torch.device('cuda', 0)
+    evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
     model = FreqAwareEmbeddingBag(NUM_EMBED,
                                   EMBED_DIM,
                                   mode='mean',
                                   include_last_offset=True,
                                   cuda_row_num=BATCH_SIZE * 2,
-                                  ids_freq_mapping=None).to(device)
+                                  ids_freq_mapping=None,
+                                  evict_strategy=evict_strategy).to(device)
 
     assert model.weight.shape[0] == NUM_EMBED
     ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
@@ -231,6 +235,5 @@ def test_parallel_freq_aware_embed(world_size):
 
 
 if __name__ == '__main__':
-    test_cachemgr()
-    # test_freq_aware_embed()
+    test_freq_aware_embed(True)
     # test_parallel_freq_aware_embed(2)