diff --git a/colossalai/nn/_ops/cache_embedding/__init__.py b/colossalai/nn/_ops/cache_embedding/__init__.py
index 4693e9055..0510e89f6 100644
--- a/colossalai/nn/_ops/cache_embedding/__init__.py
+++ b/colossalai/nn/_ops/cache_embedding/__init__.py
@@ -1,4 +1,5 @@
 from .cache_mgr import CachedParamMgr
 from .copyer import LimitBuffIndexCopyer
+from .freq_aware_embedding import FreqAwareEmbeddingBag
 
-__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer']
\ No newline at end of file
+__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag']
diff --git a/colossalai/nn/_ops/cache_embedding/freq_aware_embedding.py b/colossalai/nn/_ops/cache_embedding/freq_aware_embedding.py
new file mode 100644
index 000000000..95f6996fa
--- /dev/null
+++ b/colossalai/nn/_ops/cache_embedding/freq_aware_embedding.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn.functional as F
+from typing import List, Optional, Iterator, Tuple
+
+from .base_embedding import BaseEmbeddingBag
+from .cache_mgr import CachedParamMgr
+from torch.nn.parameter import Parameter
+
+
+class FreqAwareEmbeddingBag(BaseEmbeddingBag):
+
+    def __init__(self, num_embeddings, embedding_dim, dtype=None, *args, **kwargs):
+        super(FreqAwareEmbeddingBag, self).__init__(num_embeddings, embedding_dim, *args, **kwargs)
+        self._weight = torch.randn(self.num_embeddings, self.embedding_dim, device='cpu', dtype=dtype)
+
+    def preprocess(self,
+                   cuda_row_num: int,
+                   ids_freq_mapping: Optional[List[int]] = None,
+                   warmup_ratio=0.7,
+                   buffer_size=50_000):
+        """
+        Call this after the module is initialized.
+        Reorders the weight rows according to ids_freq_mapping,
+        then hands the weight over to a CachedParamMgr.
+        Args:
+            cuda_row_num (int): number of rows that can be hosted in CUDA memory
+            ids_freq_mapping (List[int]): a list where the index is the id and the value is its access frequency
+            warmup_ratio (float): the ratio of rows preloaded into the CUDA cache
+        """
+        self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size)
+        self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
+
+    def forward(self, indices, offsets=None, per_sample_weights=None):
+        with torch.no_grad():
+            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
+
+        embeddings = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                     self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                     per_sample_weights, self.include_last_offset, self.padding_idx)
+
+        return embeddings
+
+    @property
+    def weight(self):
+        assert self.cache_weight_mgr is not None
+        return self.cache_weight_mgr.cpu_weight.narrow(0, 0, self.num_embeddings)
+
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
+        yield 'weight', self.cache_weight_mgr.cuda_cached_weight
+
+    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        yield self.cache_weight_mgr.cuda_cached_weight
+
+    @property
+    def num_hits_history(self):
+        return self.cache_weight_mgr.num_hits_history
+
+    @property
+    def num_miss_history(self):
+        return self.cache_weight_mgr.num_miss_history
+
+    @property
+    def num_write_back_history(self):
+        return self.cache_weight_mgr.num_write_back_history
+
+    @property
+    def swap_in_bandwidth(self):
+        if self.cache_weight_mgr._cpu_to_cuda_numel > 0:
+            return self.cache_weight_mgr._cpu_to_cuda_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
+                self.cache_weight_mgr._cpu_to_cuda_elpase
+        else:
+            return 0
+
+    @property
+    def swap_out_bandwidth(self):
+        if self.cache_weight_mgr._cuda_to_cpu_numel > 0:
+            return self.cache_weight_mgr._cuda_to_cpu_numel * self.cache_weight_mgr.elem_size_in_byte / 1e6 / \
+                self.cache_weight_mgr._cuda_to_cpu_elapse
+        return 0
+
+    @property
+    def input_id_percent_in_load_chunk(self):
+        return 0    # np.mean(self.cache_weight_mgr.input_id_percent_in_load_chunk) * 100
diff --git a/tests/test_tensor/ops/test_cache_embedding.py b/tests/test_tensor/ops/test_cache_embedding.py
index 6546af361..688d59b91 100644
--- a/tests/test_tensor/ops/test_cache_embedding.py
+++ b/tests/test_tensor/ops/test_cache_embedding.py
@@ -7,12 +7,26 @@ import numpy as np
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.nn._ops.cache_embedding import CachedParamMgr
+from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag
 
-NUM_EMBED, EMBED_DIM = 100, 8
+NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
 
 
+def synthesize_1d_sparse_feature(
+    batch_size,
+    num_embed,
+    device,
+):
+    indices_in_batch = batch_size * 2
+    indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long)
+    offsets = torch.from_numpy(
+        np.array([
+            0, *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))), indices_in_batch
+        ])).to(device).long()
+    return indices, offsets
+
+
 def test_cachemgr():
     model = torch.nn.EmbeddingBag(10000, 128)
     # 10 chunks, 5 in cuda
@@ -70,6 +84,50 @@ def test_reorder_with_freq():
         f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
 
 
+def test_freq_aware_embed():
+    device = torch.device('cuda', 0)
+    model = FreqAwareEmbeddingBag(
+        NUM_EMBED,
+        EMBED_DIM,
+        mode='mean',
+        include_last_offset=True,
+    ).to(device)
+    model.preprocess(cuda_row_num=BATCH_SIZE * 2, ids_freq_mapping=None)
+
+    assert model.weight.shape[0] == NUM_EMBED
+    ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
+                                                      mode='mean',
+                                                      include_last_offset=True,
+                                                      freeze=False)
+
+    assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device))
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+    ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)
+
+    for i in range(5):
+        indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device)
+        res = model(indices, offsets)
+        ref_res = ref_model(indices, offsets)
+        assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}"
+
+        grad = torch.rand_like(res)
+        # comparing gradients directly is nontrivial, so weights are compared after the update instead
+        res.backward(grad)
+        ref_res.backward(grad)
+        optimizer.step()
+        optimizer.zero_grad()
+
+        ref_optimizer.step()
+        ref_optimizer.zero_grad()
+
+    model.cache_weight_mgr.flush()
+    model_weight = model.weight.detach().to(device)
+    ref_weight = ref_model.weight.detach()
+    assert torch.allclose(model_weight, ref_weight), \
+        f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
+
+
 if __name__ == '__main__':
     # test_freq_aware_embed()
    # test_chunkmgr_admit()
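Usage sketch (not part of the patch): a minimal example of driving the new FreqAwareEmbeddingBag, distilled from test_freq_aware_embed above. It assumes a CUDA device is available; the vocabulary size, cache size, and index values below are illustrative only.

    import torch
    from colossalai.nn._ops.cache_embedding import FreqAwareEmbeddingBag

    device = torch.device('cuda', 0)
    # mode/include_last_offset follow the usual torch.nn.EmbeddingBag semantics
    model = FreqAwareEmbeddingBag(num_embeddings=10, embedding_dim=8,
                                  mode='mean', include_last_offset=True).to(device)
    # hand the CPU weight over to the CachedParamMgr; at most 16 rows are cached on CUDA
    model.preprocess(cuda_row_num=16, ids_freq_mapping=None)

    indices = torch.tensor([1, 3, 3, 7, 2, 5], dtype=torch.long, device=device)
    offsets = torch.tensor([0, 2, 4, 6], dtype=torch.long, device=device)  # last offset == len(indices)
    out = model(indices, offsets)   # shape (3, 8): one mean-pooled bag per offset pair
    out.sum().backward()            # gradients accumulate on the CUDA-cached rows
    model.cache_weight_mgr.flush()  # write cached rows back to the CPU weight, as in the test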