mirror of https://github.com/hpcaitech/ColossalAI
Add FreqAwareEmbeddingBag (#1421)
parent 6df3e19be9
commit d209aff684
@@ -1,4 +1,5 @@
 from .cache_mgr import CachedParamMgr
 from .copyer import LimitBuffIndexCopyer
+from .freq_aware_embedding import FreqAwareEmbeddingBag

-__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer']
+__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag']
@@ -0,0 +1,83 @@
+import torch
+import torch.nn.functional as F
+from typing import List, Optional, Iterator, Tuple
+
+from .base_embedding import BaseEmbeddingBag
+from .cache_mgr import CachedParamMgr
+from torch.nn.parameter import Parameter
+
+
+class FreqAwareEmbeddingBag(BaseEmbeddingBag):
+
+    def __init__(self, num_embeddings, embedding_dim, dtype=None, *args, **kwargs):
+        super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, *args, **kwargs)
+        self._weight = torch.randn(self.num_embeddings, self.embedding_dim, device='cpu', dtype=dtype)
+
+    def preprocess(self,
+                   cuda_row_num: int,
+                   ids_freq_mapping: Optional[List[int]] = None,
+                   warmup_ratio=0.7,
+                   buffer_size=50_000):
+        """
+        Called after initialization.
+        Reorders the weight rows according to ids_freq_mapping, then lets the
+        weights of the module be managed by a CachedParamMgr.
+
+        Args:
+            cuda_row_num (int): number of rows that can be hosted in CUDA memory
+            ids_freq_mapping (List[int]): a list in which the index is a row id and the value is its access frequency
+            warmup_ratio (float): ratio of rows to preload into the CUDA cache
+            buffer_size (int): buffer size forwarded to CachedParamMgr
+        """
+        self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size)
+        self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
+
+    def forward(self, indices, offsets=None, per_sample_weights=None):
+        # translate raw ids into rows of the CUDA cache, admitting any missing rows first
+        with torch.no_grad():
+            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
+
+        embeddings = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                     self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                     per_sample_weights, self.include_last_offset, self.padding_idx)
+
+        return embeddings
+
+    @property
+    def weight(self):
+        assert self.cache_weight_mgr is not None
+        return self.cache_weight_mgr.cpu_weight.narrow(0, 0, self.num_embeddings)
+
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
+        yield 'weight', self.cache_weight_mgr.cuda_cached_weight
+
+    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        yield self.cache_weight_mgr.cuda_cached_weight
+
+    @property
+    def num_hits_history(self):
+        return self.cache_weight_mgr.num_hits_history
+
+    @property
+    def num_miss_history(self):
+        return self.cache_weight_mgr.num_miss_history
+
+    @property
+    def num_write_back_history(self):
+        return self.cache_weight_mgr.num_write_back_history
+
+    @property
+    def swap_in_bandwidth(self):
+        # MB/s moved from CPU to CUDA: bytes transferred / 1e6 / elapsed seconds
+        if self.cache_weight_mgr._cpu_to_cuda_numel > 0:
+            return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
+                self.cache_weight_mgr._cpu_to_cuda_elpase
+        else:
+            return 0
+
+    @property
+    def swap_out_bandwidth(self):
+        # MB/s moved from CUDA back to CPU
+        if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
+            return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
+                self.cache_weight_mgr._cuda_to_cpu_elapse
+        return 0
+
+    @property
+    def input_id_percent_in_load_chunk(self):
+        return 0    # np.mean(self.cache_weight_mgr.input_id_percent_in_load_chunk) * 100
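For orientation, a minimal usage sketch of the new module. The sizes, device, and input values below are illustrative and not part of the commit; the only contract taken from the code above is that preprocess() must run once before the first forward:

import torch
from colossalai.nn._ops.cache_embedding import FreqAwareEmbeddingBag

device = torch.device('cuda', 0)
bag = FreqAwareEmbeddingBag(100, 8, mode='mean', include_last_offset=True).to(device)

# hand the CPU weight over to the CachedParamMgr; with no ids_freq_mapping
# the rows are kept in their original order
bag.preprocess(cuda_row_num=16, ids_freq_mapping=None, warmup_ratio=0.7)

indices = torch.tensor([3, 1, 4, 1], dtype=torch.long, device=device)
offsets = torch.tensor([0, 2, 4], dtype=torch.long, device=device)   # two bags: [3, 1] and [4, 1]
out = bag(indices, offsets)    # shape (2, 8); ids are remapped to cache rows internally

# cache statistics exposed by the module
print(bag.num_hits_history, bag.num_miss_history, bag.swap_in_bandwidth)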
@@ -7,12 +7,26 @@ import numpy as np
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use

-from colossalai.nn._ops.cache_embedding import CachedParamMgr
+from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag

-NUM_EMBED, EMBED_DIM = 100, 8
+NUM_EMBED, EMBED_DIM = 10, 8
+BATCH_SIZE = 8
+
+
+def synthesize_1d_sparse_feature(
+    batch_size,
+    num_embed,
+    device,
+):
+    indices_in_batch = batch_size * 2
+    indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long)
+    # offsets must start at 0 and, with include_last_offset=True, end at len(indices);
+    # the interior cut points are sorted random positions into the indices tensor
+    offsets = torch.from_numpy(
+        np.array([
+            0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch
+        ])).to(device).long()
+    return indices, offsets


 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda
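The helper above emits EmbeddingBag's 1D input format: one flat indices tensor plus an offsets tensor marking where each bag begins. A concrete illustration (numbers invented for exposition, not taken from the commit):

import torch

indices = torch.tensor([9, 2, 2, 7, 0, 5], dtype=torch.long)
offsets = torch.tensor([0, 2, 5, 6], dtype=torch.long)
# With include_last_offset=True this encodes three bags:
#   bag 0 -> indices[0:2] = [9, 2]
#   bag 1 -> indices[2:5] = [2, 7, 0]
#   bag 2 -> indices[5:6] = [5]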
@@ -70,6 +84,50 @@ def test_reorder_with_freq():
         f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"


+def test_freq_aware_embed():
+    device = torch.device('cuda', 0)
+    model = FreqAwareEmbeddingBag(
+        NUM_EMBED,
+        EMBED_DIM,
+        mode='mean',
+        include_last_offset=True,
+    ).to(device)
+    model.preprocess(cuda_row_num=BATCH_SIZE * 2, ids_freq_mapping=None)
+
+    assert model.weight.shape[0] == NUM_EMBED
+    ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
+                                                      mode='mean',
+                                                      include_last_offset=True,
+                                                      freeze=False)
+
+    assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device))
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+    ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)
+
+    for i in range(5):
+        indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device)
+        res = model(indices, offsets)
+        ref_res = ref_model(indices, offsets)
+        assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}"
+
+        grad = torch.rand_like(res)
+        # comparing gradients here is nontrivial, so both models take an optimizer
+        # step and their weights are compared after the loop instead
+        res.backward(grad)
+        ref_res.backward(grad)
+        optimizer.step()
+        optimizer.zero_grad()
+
+        ref_optimizer.step()
+        ref_optimizer.zero_grad()
+
+    model.cache_weight_mgr.flush()
+    model_weight = model.weight.detach().to(device)
+    ref_weight = ref_model.weight.detach()
+    assert torch.allclose(model_weight, ref_weight), \
+        f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
+
+
 if __name__ == '__main__':
     test_freq_aware_embed()
     # test_chunkmgr_admit()
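The final allclose check works only because cache_weight_mgr.flush() first writes every cached CUDA row back into the CPU-side weight, so model.weight reflects all training updates. A sketch of the same read-back pattern for checkpointing, under the assumption that flush() is the intended synchronization point (the file name is hypothetical):

model.cache_weight_mgr.flush()                        # write cached CUDA rows back to the CPU weight
state = {'weight': model.weight.detach().clone()}     # full (num_embeddings, embedding_dim) table
torch.save(state, 'freq_aware_embedding.pt')          # hypothetical checkpoint path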