From 5156d5b4f8bdf3f402b7a1fd5786eff2d2266c8b Mon Sep 17 00:00:00 2001
From: CsRic <59389055+CsRic@users.noreply.github.com>
Date: Thu, 1 Sep 2022 17:55:41 +0800
Subject: [PATCH] [embedding] add tablewise sharding for FAW (#1526)

---
 colossalai/nn/parallel/layers/__init__.py     |   5 +-
 .../layers/cache_embedding/__init__.py        |   4 +-
 .../cache_embedding/freq_aware_embedding.py   |   2 +-
 .../parallel_freq_aware_embedding.py          |   2 -
 ...parallel_freq_aware_embedding_tablewise.py | 192 ++++++++++++++++++
 tests/test_layers/test_cache_embedding.py     |  81 +++++++-
 6 files changed, 273 insertions(+), 13 deletions(-)
 create mode 100644 colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py

diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/nn/parallel/layers/__init__.py
index 1847e0e05..ee20fc65b 100644
--- a/colossalai/nn/parallel/layers/__init__.py
+++ b/colossalai/nn/parallel/layers/__init__.py
@@ -3,10 +3,11 @@ from .linear import ColoLinear
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
 
-from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy
+from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \
+    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
 
 __all__ = [
     'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
     'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
-    'LimitBuffIndexCopyer', 'EvictionStrategy'
+    'LimitBuffIndexCopyer', 'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig'
 ]
diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/nn/parallel/layers/cache_embedding/__init__.py
index e3644dc9c..1622f848c 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/__init__.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/__init__.py
@@ -2,8 +2,8 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
 from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
-
+from .parallel_freq_aware_embedding_tablewise import ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
 __all__ = [
     'CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag',
-    'EvictionStrategy'
+    'EvictionStrategy', 'ParallelFreqAwareEmbeddingBagTablewise', 'TablewiseEmbeddingBagConfig'
 ]
diff --git a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
index ca911f9d1..ad6e39b7d 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/freq_aware_embedding.py
@@ -99,7 +99,7 @@ class FreqAwareEmbeddingBag(BaseEmbeddingBag):
     def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None):
         with torch.no_grad():
             reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
-
+
         embeddings = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
                                      self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, per_sample_weights,
                                      self.include_last_offset, self.padding_idx)
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
index 8213926ae..5c2f65b76 100644
--- a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding.py
@@ -79,10 +79,8 @@ class ParallelFreqAwareEmbeddingBag(FreqAwareEmbeddingBag):
         output_shard = F.embedding_bag(reorder_ids.cuda(), self.cache_weight_mgr.cuda_cached_weight, offsets,
                                        self.max_norm, self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse, per_sample_weights,
                                        self.include_last_offset, self.padding_idx)
-
         if shape_hook is not None:
             output_shard = shape_hook(output_shard)
-
         output_full = dual_all_to_all(output_shard,
                                       self.weight.get_process_group(),
                                       scatter_dim=scatter_dim,
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
new file mode 100644
index 000000000..f22f3f993
--- /dev/null
+++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_freq_aware_embedding_tablewise.py
@@ -0,0 +1,192 @@
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+import torch.nn as nn
+from typing import List, Optional, Iterator, Tuple
+import abc
+
+from .freq_aware_embedding import FreqAwareEmbeddingBag
+
+from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor
+from .cache_mgr import CachedParamMgr, EvictionStrategy
+
+
+class TablewiseEmbeddingBagConfig:
+    '''
+    example:
+    def prepare_tablewise_config(args, cache_ratio, ...):
+        embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
+        ...
+        return embedding_bag_config_list
+    '''
+
+    def __init__(self,
+                 num_embeddings: int,
+                 cuda_row_num: int,
+                 assigned_rank: int = 0,
+                 buffer_size=50_000,
+                 ids_freq_mapping=None,
+                 initial_weight: torch.Tensor = None,
+                 name: str = ""):
+        self.num_embeddings = num_embeddings
+        self.cuda_row_num = cuda_row_num
+        self.assigned_rank = assigned_rank
+        self.buffer_size = buffer_size
+        self.ids_freq_mapping = ids_freq_mapping
+        self.initial_weight = initial_weight
+        self.name = name
+
+
+def _all_to_all_for_tablewise(x: torch.Tensor, pg: ProcessGroup, scatter_strides: List[int], gather_strides: List[int], forward=True) -> torch.Tensor:
+    world_size = pg.tp_world_size()
+    rank = pg.tp_local_rank()
+    if world_size == 1:
+        return x
+    assert x.device.type == 'cuda', "Currently, the collective _dual_all_to_all only supports the NCCL backend"
+    if forward:
+        scatter_list = list(x.split(scatter_strides, 0))
+        gather_list = [torch.empty(scatter_strides[rank], gather_strides[i], dtype=x.dtype,
+                                   device=x.device) for i in range(world_size)]
+        torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
+        return torch.cat(gather_list, 1).contiguous()
+    else:
+        # splitting on dim 1 loses contiguity, so make each chunk contiguous before the collective
+        scatter_list = [each.contiguous() for each in x.split(scatter_strides, 1)]
+        gather_list = [torch.empty(gather_strides[i], scatter_strides[rank], dtype=x.dtype,
+                                   device=x.device) for i in range(world_size)]
+        torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
+        return torch.cat(gather_list, 0).contiguous()
+
+
+class _DualAllToAllForTablewise(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, pg, scatter_strides, gather_strides):
+        ctx.pg = pg
+        ctx.scatter_strides = scatter_strides
+        ctx.gather_strides = gather_strides
+        return _all_to_all_for_tablewise(x, pg, scatter_strides, gather_strides, forward=True)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return _all_to_all_for_tablewise(grad, ctx.pg, ctx.gather_strides, ctx.scatter_strides, forward=False), None, None, None
+
+
+def _dual_all_to_all(x, pg, scatter_strides, gather_strides):
+    return _DualAllToAllForTablewise.apply(x, pg, scatter_strides, gather_strides)
+
+
+class ParallelFreqAwareEmbeddingBagTablewise(abc.ABC, nn.Module):
+    '''
+    Every table assigned to this class instance is managed by its own FreqAwareEmbeddingBag.
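+    Tables are sharded tablewise: each rank holds the full rows of the tables assigned to it.
+    In forward, every rank slices the globally batched (indices, offsets) input down to its own
+    tables, performs the lookups locally, and then exchanges the per-table results with an
+    all_to_all so that each rank ends up with the full embedding width for its shard of the batch.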
+    '''
+
+    def __init__(self,
+                 embedding_bag_config_list: List[TablewiseEmbeddingBagConfig],
+                 embedding_dim: int,
+                 padding_idx=None,
+                 max_norm=None,
+                 norm_type=2.,
+                 scale_grad_by_freq=False,
+                 sparse=False,
+                 mode='mean',
+                 include_last_offset=False,
+                 dtype=None,
+                 device=None,
+                 warmup_ratio=0.7,
+                 pin_weight=False,
+                 evict_strategy: EvictionStrategy = EvictionStrategy.LFU):
+        super(ParallelFreqAwareEmbeddingBagTablewise, self).__init__()
+        self.rank = dist.get_rank()
+        self.world_size = dist.get_world_size()
+        self.global_table_assign_list = [config.assigned_rank for config in embedding_bag_config_list]
+        self.global_table_num_embeddings_list = [config.num_embeddings for config in embedding_bag_config_list]
+        self.global_tables_num = len(embedding_bag_config_list)
+        # cumulative row offset of each table in the global id space
+        self.global_tables_offsets = torch.cumsum(torch.tensor([0] + self.global_table_num_embeddings_list), 0)
+
+        self.assigned_table_list: List[int] = []
+        for i, rank in enumerate(self.global_table_assign_list):
+            if rank == self.rank:
+                self.assigned_table_list.append(i)
+        self.include_last_offset = include_last_offset
+        self.pg = ProcessGroup(tp_degree=self.world_size)
+
+        # prepare the local FreqAwareEmbeddingBag list, one bag per assigned table
+        self.freq_aware_embedding_bag_list: nn.ModuleList = nn.ModuleList()
+        for config in embedding_bag_config_list:
+            if config.assigned_rank != self.rank:
+                continue
+            self.freq_aware_embedding_bag_list.append(
+                FreqAwareEmbeddingBag(
+                    num_embeddings=config.num_embeddings,
+                    embedding_dim=embedding_dim,
+                    padding_idx=padding_idx,
+                    max_norm=max_norm,
+                    norm_type=norm_type,
+                    scale_grad_by_freq=scale_grad_by_freq,
+                    sparse=sparse,
+                    _weight=config.initial_weight,
+                    mode=mode,
+                    include_last_offset=include_last_offset,
+                    dtype=dtype,
+                    device=device,
+                    cuda_row_num=config.cuda_row_num,
+                    ids_freq_mapping=config.ids_freq_mapping,
+                    warmup_ratio=warmup_ratio,
+                    buffer_size=config.buffer_size,
+                    pin_weight=pin_weight,
+                    evict_strategy=evict_strategy
+                )
+            )
+
+        # prepare the output shape list for all_to_all: embedding width held by each rank
+        self.embedding_dim_per_rank = [0 for i in range(self.world_size)]
+        for rank in self.global_table_assign_list:
+            self.embedding_dim_per_rank[rank] += embedding_dim
+
+    def forward(self, indices: torch.Tensor, offsets: torch.Tensor = None, per_sample_weights=None, shape_hook=None):
+        # determine which part of the global indices this rank handles
+        batch_size = (offsets.shape[0]) // self.global_tables_num
+        local_output_list = []
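+        # Layout note: `offsets` is the concatenation of every table's per-sample offsets, table by
+        # table, so table `t` owns offsets[batch_size * t : batch_size * (t + 1)] (plus one extra
+        # entry when include_last_offset=True) and its indices live between offsets[batch_size * t]
+        # and offsets[batch_size * (t + 1)]. Subtracting the table's entry in global_tables_offsets
+        # turns global ids into table-local ids before the local FreqAwareEmbeddingBag lookup.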
+        for i, handle_table in enumerate(self.assigned_table_list):
+            indices_start_position = offsets[batch_size * handle_table]
+            if (not self.include_last_offset) and (batch_size * (handle_table + 1) >= indices.shape[0]):
+                # special case: without a trailing offset entry, the last table's indices run to the end
+                indices_end_position = indices.shape[0]
+            else:
+                indices_end_position = offsets[batch_size * (handle_table + 1)]
+
+            local_indices = indices[indices_start_position:indices_end_position] - \
+                self.global_tables_offsets[handle_table]
+            if self.include_last_offset:
+                local_offsets = offsets[batch_size * handle_table:batch_size
+                                        * (handle_table + 1) + 1] - offsets[batch_size * (handle_table)]
+            else:
+                local_offsets = offsets[batch_size * handle_table:batch_size
+                                        * (handle_table + 1)] - offsets[batch_size * (handle_table)]
+            local_per_sample_weights = None
+            if per_sample_weights is not None:
+                local_per_sample_weights = per_sample_weights[indices_start_position:indices_end_position]
+            local_output_list.append(
+                self.freq_aware_embedding_bag_list[i](
+                    local_indices,
+                    local_offsets,
+                    local_per_sample_weights
+                )
+            )
+
+        # concatenate the per-table outputs along the second (embedding) dimension,
+        # shape = (batch_size, len(assigned_table_list) * embedding_dim)
+        local_output = torch.cat(local_output_list, 1)
+        # exchange the concatenated shards between ranks with all_to_all
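+        # The all_to_all scatters local_output along the batch dimension (split as evenly as possible,
+        # with the first `remains` ranks taking one extra sample) and gathers along the embedding
+        # dimension, so every rank receives the full concatenated embedding width
+        # (sum of embedding_dim_per_rank) for its own slice of the batch.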
+        remains = batch_size % self.world_size
+        scatter_strides = [batch_size // self.world_size + int(i < remains) for i in range(self.world_size)]
+        output_full = _dual_all_to_all(local_output, self.pg, scatter_strides, self.embedding_dim_per_rank)
+        if shape_hook is not None:
+            output_full = shape_hook(output_full)
+        return output_full
diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py
index 50fbb732c..2a398719e 100644
--- a/tests/test_layers/test_cache_embedding.py
+++ b/tests/test_layers/test_cache_embedding.py
@@ -12,7 +12,9 @@ from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \
     ColoTensor, ColoTensorSpec
-from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, EvictionStrategy, \
+    ParallelFreqAwareEmbeddingBagTablewise, TablewiseEmbeddingBagConfig
+from typing import List
 
 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
@@ -200,7 +202,72 @@ def gather_tensor(tensor, rank, world_size):
     return gather_list
 
 
-def run_parallel_freq_aware_embed(rank, world_size):
+def run_parallel_freq_aware_embed_tablewise(rank, world_size):
+    if world_size != 2:
+        return
+    device = torch.device('cuda', torch.cuda.current_device())
+
+    # initialize weights
+    # 3 feature tables, holding the global ids 0~5, 6~10 and 11~17 respectively
+    weight_table1 = torch.rand(6, 5)
+    weight_table2 = torch.rand(5, 5)
+    weight_table3 = torch.rand(7, 5)
+    embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = []
+    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
+        num_embeddings=6, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table1.clone().detach().cpu()))
+    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
+        num_embeddings=5, cuda_row_num=4, assigned_rank=0, initial_weight=weight_table2.clone().detach().cpu()))
+    embedding_bag_config_list.append(TablewiseEmbeddingBagConfig(
+        num_embeddings=7, cuda_row_num=4, assigned_rank=1, initial_weight=weight_table3.clone().detach().cpu()))
+
+    model = ParallelFreqAwareEmbeddingBagTablewise(
+        embedding_bag_config_list,
+        embedding_dim=5,
+        evict_strategy=EvictionStrategy.LFU,
+        include_last_offset=True
+    )
+    # demo explanation (inputs are given in KJT, i.e. KeyedJaggedTensor, format):
+    '''
+    batch   feature 1       feature 2       feature 3
+    input0  [1,2,3]         [6,7]           []
+    input1  []              [9]             [13,15]
+    input2  [1,5]           [6,8]           [11]
+                 ↑               ↑               ↑
+                 rank 0          rank 0          rank 1
+    '''
+    res = model(torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device),
+                torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device))
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
+    rand_grad = torch.rand(3, 5 * 3, dtype=res.dtype, device=res.device)
+    if rank == 0:
+        fake_grad = rand_grad[0:2]
+    else:
+        fake_grad = rand_grad[2:]
+
+    res.backward(fake_grad)
+    optimizer.step()
+    optimizer.zero_grad()
+
+    # check correctness of weight_table2 against a plain torch.nn.EmbeddingBag
+    if rank == 0:
+        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_table2.detach().clone(),
+                                                          include_last_offset=True,
+                                                          freeze=False).to(device)
+        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
+        ref_grad = rand_grad[:, 5:10]
+        ref_res = ref_model(torch.tensor([0, 1, 3, 0, 2], device=device), torch.tensor([0, 2, 3, 5], device=device))
+        ref_res.backward(ref_grad)
+        ref_optimizer.step()
+        ref_optimizer.zero_grad()
+
+        model.freq_aware_embedding_bag_list[1].cache_weight_mgr.flush()    # write cached rows back to the cpu weight
+        recover_weight = model.freq_aware_embedding_bag_list[1].cache_weight_mgr.weight
+        assert torch.allclose(recover_weight, ref_model.weight.detach().cpu()
+                              ), f"{recover_weight - ref_model.weight.detach().cpu()}"
+
+
+def run_parallel_freq_aware_embed_columnwise(rank, world_size):
     device = torch.device('cuda', torch.cuda.current_device())
 
     num_embed = 100
@@ -219,7 +286,8 @@ def run_parallel_freq_aware_embed(rank, world_size):
     model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                           include_last_offset=True,
                                                           freeze=False,
-                                                          cuda_row_num=batch_size * 2)
+                                                          cuda_row_num=batch_size * 2,
+                                                          )
 
     assert model.cache_weight_mgr.weight.device.type == 'cpu'
     assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
@@ -269,7 +337,8 @@ def run_parallel_freq_aware_embed(rank, world_size):
 
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_parallel_freq_aware_embed(rank, world_size)
+    # run_parallel_freq_aware_embed_columnwise(rank, world_size)
+    run_parallel_freq_aware_embed_tablewise(rank, world_size)
 
 
 @pytest.mark.dist
@@ -281,6 +350,6 @@ def test_parallel_freq_aware_embed(world_size):
 
 
 if __name__ == '__main__':
-    test_freq_aware_embed(True)
-    # test_parallel_freq_aware_embed(2)
+    # test_freq_aware_embed(True)
+    test_parallel_freq_aware_embed(2)
     # test_lfu_strategy(False)