2022-08-10 05:44:30 +00:00
|
|
|
import random
|
2023-04-06 06:51:35 +00:00
|
|
|
from typing import List
|
2022-08-09 07:17:17 +00:00
|
|
|
|
2023-04-06 06:51:35 +00:00
|
|
|
import numpy as np
|
|
|
|
import pytest
|
2022-08-11 05:43:24 +00:00
|
|
|
import torch
|
|
|
|
|
2022-08-10 05:44:30 +00:00
|
|
|
import colossalai
|
2023-09-11 08:24:28 +00:00
|
|
|
from colossalai.legacy.nn.parallel.layers import (
|
2023-04-06 06:51:35 +00:00
|
|
|
CachedEmbeddingBag,
|
|
|
|
CachedParamMgr,
|
|
|
|
EvictionStrategy,
|
|
|
|
ParallelCachedEmbeddingBag,
|
|
|
|
ParallelCachedEmbeddingBagTablewise,
|
|
|
|
TablewiseEmbeddingBagConfig,
|
|
|
|
)
|
2023-09-18 08:31:06 +00:00
|
|
|
from colossalai.legacy.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
|
|
|
from colossalai.tensor import ColoTensor
|
2023-04-06 06:51:35 +00:00
|
|
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
|
2022-08-09 07:17:17 +00:00
|
|
|
|
2022-08-09 08:26:12 +00:00
|
|
|
# Shape of the embedding table used by the single-device tests.
NUM_EMBED = 10
EMBED_DIM = 8
# Batch size handed to synthesize_1d_sparse_feature in test_freq_aware_embed.
BATCH_SIZE = 8
|
|
|
|
|
|
|
|
|
2022-08-10 05:44:30 +00:00
|
|
|
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch global RNGs with ``seed`` for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
|
|
|
|
|
|
|
|
2022-08-09 08:26:12 +00:00
|
|
|
def synthesize_1d_sparse_feature(
    batch_size,
    num_embed,
    device,
):
    """Build a random (indices, offsets) pair in EmbeddingBag input format.

    Args:
        batch_size: half the number of indices to draw (``2 * batch_size`` indices are made).
        num_embed: exclusive upper bound for every index value.
        device: device on which both returned tensors are placed.

    Returns:
        ``indices``: long tensor of ``2 * batch_size`` values drawn with ``torch.randint``.
        ``offsets``: sorted long tensor of length ``2 * batch_size + 1`` that starts at 0 and
        ends at ``2 * batch_size`` (the layout expected with ``include_last_offset=True``);
        the inner boundaries come from NumPy's RNG and may repeat, giving empty bags.
    """
    num_indices = 2 * batch_size
    indices = torch.randint(low=0, high=num_embed, size=(num_indices,), device=device, dtype=torch.long)

    inner_bounds = np.sort(np.random.randint(low=0, high=num_indices, size=(num_indices - 1,)))
    offsets_np = np.concatenate(([0], inner_bounds, [num_indices]))
    offsets = torch.from_numpy(offsets_np).to(device).long()
    return indices, offsets
|
|
|
|
|
|
|
|
|
2022-08-24 09:37:22 +00:00
|
|
|
@pytest.mark.skip
@clear_cache_before_run()
def test_cachemgr():
    """Exercise admission, eviction and flushing of a CachedParamMgr with 5 CUDA rows."""
    # NOTE(review): this test is skipped and drives private CachedParamMgr helpers
    # (_admit, _chunk_in_cuda, _evict, cuda_available_chunk_num); confirm they still
    # exist before re-enabling it.
    embedding = torch.nn.EmbeddingBag(10000, 128)
    # 10000-row weight, of which 5 rows are cached on CUDA
    mgr = CachedParamMgr(embedding.weight.detach(), 5)
    assert mgr.cuda_row_num == 5

    mgr._admit(1)
    assert not mgr._chunk_in_cuda(2)
    assert mgr._chunk_in_cuda(1)

    mgr._admit(8)
    # two rows admitted into 5 slots -> 3 slots remain free
    assert mgr.cuda_available_chunk_num == 3

    mgr._evict()
    assert mgr.cuda_available_chunk_num == 4

    for rows in ([9, 6, 5], [3, 4, 5]):
        mgr._prepare_rows_on_cuda(torch.tensor(rows, dtype=torch.long, device=0))

    mgr.flush()
    assert mgr.cuda_available_chunk_num == 5
|
|
|
|
|
|
|
|
|
2023-04-06 06:51:35 +00:00
|
|
|
@clear_cache_before_run()
def test_reorder_with_freq():
    """``CachedParamMgr.reorder`` should map every row to its position in descending-frequency order."""
    num_embed, chunk_size, num_chunk = 100, 1, 5

    freqs = torch.randint(10000, size=(num_embed,))
    descending = torch.argsort(freqs, descending=True)
    # position[r] = index of row r inside `descending` (the inverse permutation),
    # computed with one scatter instead of a per-row list search
    position = torch.empty_like(descending)
    position[descending] = torch.arange(num_embed, dtype=descending.dtype)

    dev = torch.device("cuda")
    chunkid = torch.div(position, chunk_size, rounding_mode="floor").to(dev)
    offset_in_chunk = torch.remainder(position, chunk_size).to(dev)

    mgr = CachedParamMgr(torch.rand(num_embed, 2), num_chunk)
    mgr.reorder(freqs)

    indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev))
    mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode="floor")
    mgr_offsets = torch.remainder(indices, chunk_size)
    assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
    assert torch.allclose(offset_in_chunk, mgr_offsets), f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"
|
2022-08-09 07:17:17 +00:00
|
|
|
|
|
|
|
|
2023-04-06 06:51:35 +00:00
|
|
|
@clear_cache_before_run()
@parameterize("use_LFU", [True, False])
def test_freq_aware_embed(use_LFU: bool):
    """Compare CachedEmbeddingBag with torch.nn.EmbeddingBag over five SGD steps.

    Forward outputs are compared at every step. Weights are compared once at the end,
    after ``cache_weight_mgr.flush()``.

    Args:
        use_LFU: evict with ``EvictionStrategy.LFU`` if True, else ``EvictionStrategy.DATASET``.
    """
    device = torch.device("cuda", 0)
    evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
    model = CachedEmbeddingBag(
        NUM_EMBED,
        EMBED_DIM,
        mode="mean",
        include_last_offset=True,
        cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
        ids_freq_mapping=None,
        evict_strategy=evict_strategy,
    ).to(device)

    assert model.weight.shape[0] == NUM_EMBED
    # clone() so the reference owns its storage (as in the parallel tests): if model.weight
    # already lived on `device`, .to() would be a no-op and both models would share one tensor
    ref_model = torch.nn.EmbeddingBag.from_pretrained(
        model.weight.detach().to(device).clone(), mode="mean", include_last_offset=True, freeze=False
    )
    assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device))

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    for _ in range(5):
        indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device)
        res = model(indices, offsets)
        ref_res = ref_model(indices, offsets)
        assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}"

        grad = torch.rand_like(res)
        # comparing gradient here is nontrivial
        res.backward(grad)
        ref_res.backward(grad)
        optimizer.step()
        optimizer.zero_grad()

        ref_optimizer.step()
        ref_optimizer.zero_grad()

    model.cache_weight_mgr.flush()
    model_weight = model.weight.detach().to(device)
    ref_weight = ref_model.weight.detach()
    # The table has only NUM_EMBED rows, so report it whole (a fixed row slice such as
    # [10:18] would be empty) together with the largest deviation.
    assert torch.allclose(model_weight, ref_weight), (
        f"max abs diff: {(model_weight - ref_weight).abs().max().item()}, "
        f"model weight: {model_weight}, reference: {ref_weight}"
    )
|
2022-08-09 08:26:12 +00:00
|
|
|
|
2022-08-29 06:22:07 +00:00
|
|
|
|
2023-04-06 06:51:35 +00:00
|
|
|
@clear_cache_before_run()
@parameterize("init_freq", [True, False])
def test_lfu_strategy(init_freq: bool):
    """Minimal behavioral check of LFU eviction in CachedEmbeddingBag.

    A 5-row bag caches 3 rows. A fixed access pattern makes row 0 the most frequently
    used, then row 2, then row 1. The hit counts of the last six lookups are then
    checked against the expected LFU decisions.
    """
    bag = CachedEmbeddingBag(
        5,
        5,
        cache_ratio=3 / 5,
        buffer_size=0,
        pin_weight=True,
        ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
        warmup_ratio=1.0,
        evict_strategy=EvictionStrategy.LFU,
    )
    offsets = torch.tensor([0], device="cuda:0")

    def lookup(ids):
        bag.forward(torch.tensor(ids, device="cuda:0"), offsets)

    # prepare frequency learning info
    warmup = [[2], [1, 2], [0, 2]] + [[0, 1, 2]] * 4 + [[0, 2]] * 4 + [[0]] * 4
    for ids in warmup:
        lookup(ids)

    # check strategy
    lookup([0, 1, 2])
    lookup([0, 1, 2])
    lookup([3])  # miss, evict 1
    lookup([2])  # hit
    lookup([4])  # miss, evict 3
    lookup([2])  # hit
    lookup([0])  # hit

    assert torch.allclose(
        torch.Tensor(bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])
    ), "LFU strategy behavior failed"
|
2022-08-29 06:22:07 +00:00
|
|
|
|
|
|
|
|
2022-08-10 05:44:30 +00:00
|
|
|
def gather_tensor(tensor, rank, world_size):
    """Gather ``tensor`` from every rank onto rank 0.

    Returns a list of ``world_size`` tensors (one per rank) on rank 0, and an empty
    list on every other rank.
    """
    buffers = [torch.empty_like(tensor) for _ in range(world_size)] if rank == 0 else []
    torch.distributed.gather(tensor, buffers, dst=0)
    return buffers
|
|
|
|
|
|
|
|
|
2022-09-01 09:55:41 +00:00
|
|
|
def run_parallel_freq_aware_embed_tablewise(rank, world_size):
    """Check ParallelCachedEmbeddingBagTablewise against a single-device reference.

    Three tables are placed on two ranks: tables 1 and 2 on rank 0, table 3 on rank 1.
    After one SGD step, rank 0 flushes its cache and compares its local weight with the
    first 11 rows of a ``torch.nn.EmbeddingBag`` trained on the same input and gradient.
    Does nothing unless ``world_size == 2``.
    """
    if world_size != 2:
        return
    device = torch.device("cuda", torch.cuda.current_device())

    # initialize weight
    # 3 feature tables. idx: 0~5, 6~10, 11~17
    weight_tables = torch.rand(18, 5)
    weight_table1 = weight_tables[0:6]
    weight_table2 = weight_tables[6:11]
    weight_table3 = weight_tables[11:18]
    embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [
        TablewiseEmbeddingBagConfig(
            num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu()
        ),
        TablewiseEmbeddingBagConfig(
            num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu()
        ),
        TablewiseEmbeddingBagConfig(
            num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu()
        ),
    ]
    if rank == 0:
        _weight = torch.cat([weight_table1, weight_table2], 0)
    else:
        _weight = weight_table3
    model = ParallelCachedEmbeddingBagTablewise(
        embedding_bag_config_list,
        embedding_dim=5,
        _weight=_weight,
        include_last_offset=True,
        cache_ratio=0.5,
        buffer_size=0,
        evict_strategy=EvictionStrategy.LFU,
    )

    # Input, in KJT format (grouped by feature, then by sample):
    #   batch    feature 1   feature 2   feature 3
    #   input0   [1,2,3]     [6,7]       []
    #   input1   []          [9]         [13,15]
    #   input2   [1,5]       [6,8]       [11]
    #            rank 0      rank 0      rank 1
    res = model(
        torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
        torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device),
        already_split_along_rank=False,
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
    # Rank 0's reference below uses all of rand_grad, while each rank back-props only its
    # own rows. Input2 (rank 1's row) also touches features 1 and 2, which live on rank 0,
    # so presumably its gradient is routed back to rank 0's tables. The comparison is
    # therefore only meaningful if every rank holds the same rand_grad. Broadcast it
    # rather than relying on identical RNG state across processes.
    torch.distributed.broadcast(rand_grad, src=0)
    if rank == 0:
        fake_grad = rand_grad[0:2]
    else:
        fake_grad = rand_grad[2:]
    res.backward(fake_grad)
    optimizer.step()
    optimizer.zero_grad()

    # check correctness
    if rank == 0:
        ref_model = torch.nn.EmbeddingBag.from_pretrained(
            weight_tables.detach().clone(), include_last_offset=True, freeze=False
        ).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
        # (batch, 3 * 5) -> (3 features * batch, 5): feature-major, matching the reference bag order
        ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
        ref_res = ref_model(
            torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
            torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device),
        )
        ref_res.backward(ref_fake_grad)
        ref_optimizer.step()
        ref_optimizer.zero_grad()

        model.cache_weight_mgr.flush()
        recover_weight = model.cache_weight_mgr.weight.to(device)
        # rank 0 owns tables 1 and 2, i.e. global rows 0..10
        ref_weight = ref_model.weight.detach()[:11]
        assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"
|
2022-09-01 09:55:41 +00:00
|
|
|
|
2022-09-06 02:41:20 +00:00
|
|
|
|
2022-09-01 09:55:41 +00:00
|
|
|
def run_parallel_freq_aware_embed_columnwise(rank, world_size):
    """Compare a column-sharded ParallelCachedEmbeddingBag with a single-device reference.

    The weight is sharded along its last dimension over ``world_size`` ranks. For five
    random batches, rank 0 checks the gathered forward output against
    ``torch.nn.EmbeddingBag``. At the end it checks the gathered, flushed weights.
    """
    device = torch.device("cuda", torch.cuda.current_device())

    num_embed, embed_dim, batch_size = 100, 16, 4

    set_seed(4321)
    weight = torch.rand(num_embed, embed_dim)
    coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None)

    # give the ColoTensor a 1D tensor-parallel spec: last dim sharded across all ranks
    coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
    coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))

    model = ParallelCachedEmbeddingBag.from_pretrained(
        coloweight,
        include_last_offset=True,
        freeze=False,
        cache_ratio=batch_size * 2 / num_embed,
    )

    mgr = model.cache_weight_mgr
    assert mgr.weight.device.type == "cpu"
    assert mgr.cuda_cached_weight.requires_grad
    local_ref_weight = torch.tensor_split(weight, world_size, -1)[rank]
    print(f"model weight: {mgr.weight.shape}, ref weight: {local_ref_weight.shape}")
    assert torch.allclose(local_ref_weight, mgr.weight.detach()), f"{local_ref_weight - mgr.weight}"

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    if rank == 0:
        ref_model = torch.nn.EmbeddingBag.from_pretrained(
            weight.detach().clone(), include_last_offset=True, freeze=False
        ).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    set_seed(4321)
    for _ in range(5):
        indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
        res = model(indices, offsets)

        # every rank was re-seeded above, so presumably all draw the same full gradient;
        # each back-props only its own batch slice
        grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
        res.backward(torch.tensor_split(grad, world_size, 0)[rank])
        optimizer.step()
        optimizer.zero_grad()

        res_list = gather_tensor(res.detach(), rank, world_size)
        if rank != 0:
            continue

        ref_res = ref_model(indices, offsets)
        assert torch.allclose(ref_res, torch.cat(res_list, dim=0))
        ref_res.backward(grad)
        ref_optimizer.step()
        ref_optimizer.zero_grad()

    model.cache_weight_mgr.flush()
    weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size)
    if rank == 0:
        recover_weight = torch.cat(weight_list, dim=1)
        assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"
|
|
|
|
|
|
|
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-process entry point for ``spawn``: launch the legacy runtime, then run the tablewise check."""
    colossalai.legacy.launch(host="localhost", port=port, rank=rank, world_size=world_size, backend="nccl")
    # NOTE: run_parallel_freq_aware_embed_columnwise is currently not invoked here.
    run_parallel_freq_aware_embed_tablewise(rank, world_size)
|
2022-08-10 05:44:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
# run_parallel_freq_aware_embed_tablewise returns immediately unless world_size == 2, so
# without 2 in this list the test would assert nothing. 1 and 4 only exercise the launch path.
@pytest.mark.parametrize("world_size", [1, 2, 4])
@rerun_if_address_is_in_use()
def test_parallel_freq_aware_embed(world_size):
    """Spawn ``world_size`` processes, each running ``run_dist``."""
    spawn(run_dist, world_size)
|
2022-08-10 05:44:30 +00:00
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
if __name__ == "__main__":
|
2022-09-01 09:55:41 +00:00
|
|
|
# test_freq_aware_embed(True)
|
|
|
|
test_parallel_freq_aware_embed(2)
|
2022-08-30 06:50:02 +00:00
|
|
|
# test_lfu_strategy(False)
|