Add FreqAwareEmbeddingBag (#1421)

pull/1423/head
Jiarui Fang 2022-08-09 16:26:12 +08:00 committed by GitHub
parent 6df3e19be9
commit d209aff684
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 145 additions and 3 deletions

View File

@ -1,4 +1,5 @@
from .cache_mgr import CachedParamMgr
from .copyer import LimitBuffIndexCopyer
from .freq_aware_embedding import FreqAwareEmbeddingBag
__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer']
__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag']

View File

@ -0,0 +1,83 @@
import torch
import torch.nn.functional as F
from typing import List, Optional, Iterator, Tuple
from .base_embedding import BaseEmbeddingBag
from .cache_mgr import CachedParamMgr
from torch.nn.parameter import Parameter
class FreqAwareEmbeddingBag(BaseEmbeddingBag):
def __init__(self, num_embeddings, embedding_dim, dtype=None, *args, **kwargs):
super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, *args, **kwargs)
self._weight = torch.randn(self.num_embeddings, self.embedding_dim, device='cpu', dtype=dtype)
def preprocess(self,
cuda_row_num: int,
ids_freq_mapping: Optional[List[int]] = None,
warmup_ratio=0.7,
buffer_size=50_000):
"""
Called after initialized.
Reorder the weight rows according to the ids_freq_mapping.
Then, let the weights of the Module be managed by a CachedParamMgr.
Args:
cuda_row_num (int): number of rows can be hosted in CUDA memory
ids_freq_mapping (List[int]): a list, idx is id number, value is freq
warmup_ratio (float): the amount of rows preloaded in cuda cache
"""
self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size)
self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
def forward(self, indices, offsets=None, per_sample_weights=None):
with torch.no_grad():
reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
embeddings = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
per_sample_weights, self.include_last_offset, self.padding_idx)
return embeddings
@property
def weight(self):
assert self.cache_weight_mgr is not None
return self.cache_weight_mgr.cpu_weight.narrow(0, 0, self.num_embeddings)
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
yield 'weight', self.cache_weight_mgr.cuda_cached_weight
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
yield self.cache_weight_mgr.cuda_cached_weight
@property
def num_hits_history(self):
return self.cache_weight_mgr.num_hits_history
@property
def num_miss_history(self):
return self.cache_weight_mgr.num_miss_history
@property
def num_write_back_history(self):
return self.cache_weight_mgr.num_write_back_history
@property
def swap_in_bandwidth(self):
if self.cache_weight_mgr._cpu_to_cuda_numel > 0:
return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
self.cache_weight_mgr._cpu_to_cuda_elpase
else:
return 0
@property
def swap_out_bandwidth(self):
if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
self.cache_weight_mgr._cuda_to_cpu_elapse
return 0
@property
def input_id_percent_in_load_chunk(self):
return 0 # np.mean(self.cache_weight_mgr.input_id_percent_in_load_chunk) * 100

View File

@ -7,12 +7,26 @@ import numpy as np
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.nn._ops.cache_embedding import CachedParamMgr
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag
NUM_EMBED, EMBED_DIM = 100, 8
NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8
def synthesize_1d_sparse_feature(
batch_size,
num_embed,
device,
):
indices_in_batch = batch_size * 2
indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long)
offsets = torch.from_numpy(
np.array([
0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch
])).to(device).long()
return indices, offsets
def test_cachemgr():
model = torch.nn.EmbeddingBag(10000, 128)
# 10 chunks, 5 in cuda
@ -70,6 +84,50 @@ def test_reorder_with_freq():
f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
def test_freq_aware_embed():
device = torch.device('cuda', 0)
model = FreqAwareEmbeddingBag(
NUM_EMBED,
EMBED_DIM,
mode='mean',
include_last_offset=True,
).to(device)
model.preprocess(cuda_row_num=BATCH_SIZE * 2, ids_freq_mapping=None)
assert model.weight.shape[0] == NUM_EMBED
ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
mode='mean',
include_last_offset=True,
freeze=False)
assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)
for i in range(5):
indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device)
res = model(indices, offsets)
ref_res = ref_model(indices, offsets)
assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}"
grad = torch.rand_like(res)
# comparing gradient here is nontrivial
res.backward(grad)
ref_res.backward(grad)
optimizer.step()
optimizer.zero_grad()
ref_optimizer.step()
ref_optimizer.zero_grad()
model.cache_weight_mgr.flush()
model_weight = model.weight.detach().to(device)
ref_weight = ref_model.weight.detach()
assert torch.allclose(model_weight, ref_weight), \
f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
if __name__ == '__main__':
# test_freq_aware_embed()
# test_chunkmgr_admit()