import random
from functools import partial
from typing import List

import numpy as np
import pytest
import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import (
    ColoParameter,
    ProcessGroup,
    ShardSpec,
    ComputePattern,
    ComputeSpec,
    ColoTensor,
    ColoTensorSpec,
)
from colossalai.nn.parallel.layers import (
    CachedParamMgr,
    FreqAwareEmbeddingBag,
    ParallelFreqAwareEmbeddingBag,
    EvictionStrategy,
    ParallelFreqAwareEmbeddingBagTablewise,
    TablewiseEmbeddingBagConfig,
    ParallelFreqAwareEmbeddingBagTablewiseSpiltCache,
)

# Shape of the embedding table and batch size shared by the single-process tests.
NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8


def set_seed(seed):
    """
    To achieve reproducible results, it's necessary to fix random seeds.

    Seeds Python's ``random``, NumPy's global RNG and torch's RNG
    (``torch.manual_seed``) with the same ``seed``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def synthesize_1d_sparse_feature(
    batch_size,
    num_embed,
    device,
):
    """Generate random ``(indices, offsets)`` inputs for an embedding bag.

    Args:
        batch_size: the number of generated indices is ``2 * batch_size``.
        num_embed: exclusive upper bound of the generated indices.
        device: device the returned tensors are placed on.

    Returns:
        A tuple ``(indices, offsets)`` of ``torch.long`` tensors:
        ``indices`` holds ``2 * batch_size`` values drawn from
        ``[0, num_embed)``; ``offsets`` holds ``2 * batch_size + 1``
        non-decreasing values that start at 0 and end at ``2 * batch_size``.
    """
    indices_in_batch = batch_size * 2
    # torch RNG is consumed first, then NumPy's; keep this order so seeded runs stay reproducible.
    indices = torch.randint(low=0, high=num_embed, size=(indices_in_batch,), device=device, dtype=torch.long)
    offsets = torch.from_numpy(
        np.array([
            0,
            *np.sort(np.random.randint(low=0, high=indices_in_batch, size=(indices_in_batch - 1,))),
            indices_in_batch,
        ])).to(device).long()
    return indices, offsets


@pytest.mark.skip
def test_cachemgr():
    """Exercise private ``CachedParamMgr`` helpers (admit / evict / flush).

    Unconditionally skipped via ``pytest.mark.skip``.
    """
    model = torch.nn.EmbeddingBag(10000, 128)
    # 10 chunks, 5 in cuda
    # NOTE(review): the "10 chunks" remark does not obviously match the 10000-row table above — confirm.
    mgr = CachedParamMgr(model.weight.detach(), 5)
    assert mgr.cuda_row_num == 5

    mgr._admit(1)
    assert not mgr._chunk_in_cuda(2)
    assert mgr._chunk_in_cuda(1)

    mgr._admit(8)

    # now 3 chunks are available
    assert mgr.cuda_available_chunk_num == 3

    mgr._evict()
    assert mgr.cuda_available_chunk_num == 4

    mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0))
    mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0))

    mgr.flush()
    assert mgr.cuda_available_chunk_num == 5


def test_reorder_with_freq():
    """Check ``CachedParamMgr.reorder`` against a reference ordering.

    The reference places every row at its position in the descending sort of
    ``idx_map`` and derives the chunk id / in-chunk offset from that position.
    The manager's ``idx_map`` after ``reorder`` must decompose the same way.
    """
    num_embed = 100
    chunk_size = 1
    num_chunk = 5

    idx_map = torch.randint(10000, size=(num_embed,))
    sorted_idx = torch.argsort(idx_map, descending=True).tolist()
    # ``sorted_idx`` is a permutation, so invert it once instead of calling
    # ``list.index`` for every row (which is quadratic).
    position_of = {row: pos for pos, row in enumerate(sorted_idx)}
    chunkid = [position_of[row] // chunk_size for row in range(num_embed)]
    offset_in_chunk = [position_of[row] % chunk_size for row in range(num_embed)]

    dev = torch.device('cuda')
    chunkid = torch.tensor(chunkid, dtype=torch.long, device=dev)
    offset_in_chunk = torch.tensor(offset_in_chunk, dtype=torch.long, device=dev)

    weight = torch.rand(num_embed, 2)
    mgr = CachedParamMgr(weight, num_chunk, use_cpu_caching=dev.type == 'cpu')

    mgr.reorder(idx_map)

    indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev))
    mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
    mgr_offsets = torch.remainder(indices, chunk_size)
    assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
    assert torch.allclose(offset_in_chunk, mgr_offsets), \
        f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"


@pytest.mark.parametrize('use_LFU', [True, False])
def test_freq_aware_embed(use_LFU: bool):
    """Compare ``FreqAwareEmbeddingBag`` with ``torch.nn.EmbeddingBag``.

    Both models start from the same weight and are trained with SGD for five
    steps on identical inputs and output gradients; forward results must
    match at every step and the weights must match after flushing the cache.

    Args:
        use_LFU: use ``EvictionStrategy.LFU`` when true, otherwise
            ``EvictionStrategy.DATASET``.
    """
    device = torch.device('cuda', 0)
    evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
    model = FreqAwareEmbeddingBag(NUM_EMBED,
                                  EMBED_DIM,
                                  mode='mean',
                                  include_last_offset=True,
                                  cuda_row_num=BATCH_SIZE * 2,
                                  ids_freq_mapping=None,
                                  evict_strategy=evict_strategy).to(device)

    assert model.weight.shape[0] == NUM_EMBED
    ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
                                                      mode='mean',
                                                      include_last_offset=True,
                                                      freeze=False)

    assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device))

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    for _ in range(5):
        indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device)
        res = model(indices, offsets)
        ref_res = ref_model(indices, offsets)
        assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}"

        grad = torch.rand_like(res)
        # comparing gradient here is nontrivial
        res.backward(grad)
        ref_res.backward(grad)
        optimizer.step()
        optimizer.zero_grad()

        ref_optimizer.step()
        ref_optimizer.zero_grad()

    model.cache_weight_mgr.flush()
    model_weight = model.weight.detach().to(device)
    ref_weight = ref_model.weight.detach()
    # Report the full tensors: the table has only NUM_EMBED (10) rows, so a
    # row slice starting at 10 would print nothing useful on failure.
    assert torch.allclose(model_weight, ref_weight), \
        f"model weight: {model_weight}, reference: {ref_weight}"


@pytest.mark.parametrize('init_freq', [True, False])
def test_lfu_strategy(init_freq: bool):
    """Minimal behavioral check of the LFU eviction strategy.

    Feeds a fixed sequence of lookups to a 5-row bag with 3 CUDA rows and
    asserts that the last six entries of ``num_hits_history`` are
    ``[3, 0, 1, 0, 1, 1]``.

    Args:
        init_freq: when true, pass an explicit ``ids_freq_mapping``;
            otherwise pass ``None``.
    """
    bag = FreqAwareEmbeddingBag(5,
                                5,
                                cuda_row_num=3,
                                buffer_size=0,
                                pin_weight=True,
                                ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
                                warmup_ratio=1.0,
                                evict_strategy=EvictionStrategy.LFU)

    offsets = torch.tensor([0], device="cuda:0")

    # prepare frequency learning info (15 lookups in total)
    warmup_batches = [[2], [1, 2], [0, 2]] + [[0, 1, 2]] * 4 + [[0, 2]] * 4 + [[0]] * 4
    for ids in warmup_batches:
        bag.forward(torch.tensor(ids, device="cuda:0"), offsets)

    # check strategy
    check_batches = [
        [0, 1, 2],
        [0, 1, 2],
        [3],    # miss, evict 1
        [2],    # hit
        [4],    # miss, evict 3
        [2],    # hit
        [0],    # hit
    ]
    for ids in check_batches:
        bag.forward(torch.tensor(ids, device="cuda:0"), offsets)

    assert torch.allclose(torch.Tensor(bag.cache_weight_mgr.num_hits_history[-6:]),
                          torch.Tensor([3, 0, 1, 0, 1, 1])), \
        "LFU strategy behavior failed"


def gather_tensor(tensor, rank, world_size):
    """Gather ``tensor`` from every rank onto rank 0.

    Args:
        tensor: the local tensor contributed by this rank.
        rank: rank of the calling process.
        world_size: number of receive buffers allocated on rank 0.

    Returns:
        On rank 0, a list of ``world_size`` tensors passed to
        ``torch.distributed.gather`` as the receive buffers; on every other
        rank, an empty list.
    """
    gather_list = []
    if rank == 0:
        gather_list = [torch.empty_like(tensor) for _ in range(world_size)]

    torch.distributed.gather(tensor, gather_list, dst=0)
    return gather_list


def run_parallel_freq_aware_embed_tablewise(rank, world_size):
    """Check ``ParallelFreqAwareEmbeddingBagTablewise`` against a dense reference.

    Only runs for ``world_size == 2``; returns immediately otherwise.
    Three tables are configured (two assigned to rank 0, one to rank 1). After
    one SGD step, rank 0 compares its flushed weight with the first 11 rows of
    a reference ``torch.nn.EmbeddingBag`` trained on the same data.

    Args:
        rank: rank of the calling process.
        world_size: total number of processes.
    """
    if world_size != 2:
        return
    device = torch.device('cuda', torch.cuda.current_device())

    # initialize weight
    # 3 feature tables. idx: 0~5, 6~10, 11~17
    weight_tables = torch.rand(18, 5)
    weight_table1 = weight_tables[0:6]
    weight_table2 = weight_tables[6:11]
    weight_table3 = weight_tables[11:18]
    embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [
        TablewiseEmbeddingBagConfig(num_embeddings=6,
                                    cuda_row_num=4,
                                    assigned_rank=0,
                                    initial_weight=weight_table1.clone().detach().cpu()),
        TablewiseEmbeddingBagConfig(num_embeddings=5,
                                    cuda_row_num=4,
                                    assigned_rank=0,
                                    initial_weight=weight_table2.clone().detach().cpu()),
        TablewiseEmbeddingBagConfig(num_embeddings=7,
                                    cuda_row_num=4,
                                    assigned_rank=1,
                                    initial_weight=weight_table3.clone().detach().cpu()),
    ]
    if rank == 0:
        _weight = torch.cat([weight_table1, weight_table2], 0)
    else:
        _weight = weight_table3
    model = ParallelFreqAwareEmbeddingBagTablewise(
        embedding_bag_config_list,
        embedding_dim=5,
        _weight=_weight,
        include_last_offset=True,
        cuda_row_num=8,
        buffer_size=0,
        evict_strategy=EvictionStrategy.LFU,
    )
    # explain
    #
    #   batch     feature 1    feature 2    feature 3
    #   input0    [1,2,3]      [6,7]        []
    #   input1    []           [9]          [13,15]
    #   input2    [1,5]        [6,8]        [11]
    #                ↑            ↑            ↑
    #             rank 0       rank 0       rank 1
    #
    # in KJT format
    indices_list = [1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11]
    offsets_list = [0, 3, 3, 5, 7, 8, 10, 10, 12, 13]
    # Fresh tensors are built for each model so neither forward pass can see
    # changes the other might make to its inputs.
    res = model(torch.tensor(indices_list, device=device), torch.tensor(offsets_list, device=device))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
    if rank == 0:
        fake_grad = rand_grad[0:2]
    else:
        fake_grad = rand_grad[2:]
    res.backward(fake_grad)
    optimizer.step()
    optimizer.zero_grad()

    # check correctness
    if rank == 0:
        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(),
                                                          include_last_offset=True,
                                                          freeze=False).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
        # Split the (3, 15) gradient into three (3, 5) column blocks and stack them into (9, 5).
        ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
        ref_res = ref_model(torch.tensor(indices_list, device=device), torch.tensor(offsets_list, device=device))
        ref_res.backward(ref_fake_grad)
        ref_optimizer.step()
        ref_optimizer.zero_grad()

        model.cache_weight_mgr.flush()
        recover_weight = model.cache_weight_mgr.weight.to(device)
        # Rows 0..10 are the two tables assigned to rank 0.
        ref_weight = ref_model.weight.detach()[:11]
        assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"


def run_parallel_freq_aware_embed_columnwise(rank, world_size):
    """Check ``ParallelFreqAwareEmbeddingBag`` (weight sharded on the last dim).

    Every rank trains its shard for five SGD steps; rank 0 additionally trains
    a dense ``torch.nn.EmbeddingBag`` and compares the gathered outputs at each
    step and the gathered weights at the end.

    Args:
        rank: rank of the calling process.
        world_size: total number of processes, used as the tensor-parallel degree.
    """
    device = torch.device('cuda', torch.cuda.current_device())

    num_embed = 100
    embed_dim = 16
    batch_size = 4

    set_seed(4321)
    weight = torch.rand(num_embed, embed_dim)
    coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None)

    # initialize the tensor spec for the embedding weight parameter,
    # which is a ColoParameter.
    coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
    coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))

    model = ParallelFreqAwareEmbeddingBag.from_pretrained(
        coloweight,
        include_last_offset=True,
        freeze=False,
        cuda_row_num=batch_size * 2,
    )

    assert model.cache_weight_mgr.weight.device.type == 'cpu'
    assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
    weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
    print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}")
    assert torch.allclose(weight_in_rank,
                          model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}"

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    if rank == 0:
        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(),
                                                          include_last_offset=True,
                                                          freeze=False).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    # Re-seed so every rank draws the same inputs and gradients below.
    set_seed(4321)
    for _ in range(5):
        indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
        res = model(indices, offsets)

        grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
        grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank]
        res.backward(grad_in_rank)

        optimizer.step()
        optimizer.zero_grad()

        res_list = gather_tensor(res.detach(), rank, world_size)

        if rank == 0:
            ref_res = ref_model(indices, offsets)
            recover_res = torch.cat(res_list, dim=0)

            assert torch.allclose(ref_res, recover_res)

            ref_res.backward(grad)
            ref_optimizer.step()
            ref_optimizer.zero_grad()

    model.cache_weight_mgr.flush()
    weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size)
    if rank == 0:
        recover_weight = torch.cat(weight_list, dim=1)
        assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"


def run_dist(rank, world_size, port):
    """Per-process entry point: launch colossalai, then run the tablewise check.

    Args:
        rank: rank of this process (supplied by ``mp.spawn``).
        world_size: total number of processes.
        port: port passed to ``colossalai.launch`` (host is ``localhost``).
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # The columnwise check is currently disabled:
    # run_parallel_freq_aware_embed_columnwise(rank, world_size)
    run_parallel_freq_aware_embed_tablewise(rank, world_size)


# NOTE(review): with world_size in [1, 4] the tablewise check returns right
# away (it only runs for world_size == 2) and the columnwise check is
# commented out in run_dist, so under pytest this test currently verifies
# nothing beyond a successful launch. Confirm the intended world sizes.
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_parallel_freq_aware_embed(world_size):
    """Spawn ``world_size`` processes that each execute ``run_dist``."""
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    # test_freq_aware_embed(True)
    test_parallel_freq_aware_embed(2)
    # test_lfu_strategy(False)