diff --git a/colossalai/nn/_ops/_utils.py b/colossalai/nn/_ops/_utils.py
index a22d0511f..82ed07356 100644
--- a/colossalai/nn/_ops/_utils.py
+++ b/colossalai/nn/_ops/_utils.py
@@ -195,3 +195,39 @@ def split_forward_gather_backward(input_, process_group, dim):
 
 
 def gather_forward_split_backward(input_, process_group, dim):
     return _GatherForwardSplitBackward.apply(input_, process_group, dim)
+
+
+def _all_to_all(x: torch.Tensor, pg: ProcessGroup, scatter_dim: int, gather_dim: int) -> torch.Tensor:
+    world_size = pg.tp_world_size()
+    if world_size == 1:
+        return x
+
+    # TODO: enable the mpi backend to support CPU all_to_all
+    assert x.device.type == 'cuda', "Currently, the collective function dual_all_to_all only supports the nccl backend"
+
+    shapes = list(x.size())
+    shapes[scatter_dim] = shapes[scatter_dim] // world_size
+
+    scatter_list = [each.contiguous() for each in torch.tensor_split(x, world_size, scatter_dim)]
+    gather_list = [torch.empty(*shapes, dtype=x.dtype, device=x.device) for _ in range(world_size)]
+    torch.distributed.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
+
+    return torch.cat(gather_list, dim=gather_dim).contiguous()
+
+
+class _DualAllToAll(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, pg, scatter_dim, gather_dim):
+        ctx.scatter_dim = scatter_dim
+        ctx.gather_dim = gather_dim
+        ctx.pg = pg
+        return _all_to_all(x, pg, scatter_dim, gather_dim)
+
+    @staticmethod
+    def backward(ctx, grad):
+        return _all_to_all(grad, ctx.pg, ctx.gather_dim, ctx.scatter_dim), None, None, None
+
+
+def dual_all_to_all(x, pg, scatter_dim: int, gather_dim: int):
+    return _DualAllToAll.apply(x, pg, scatter_dim, gather_dim)
diff --git a/colossalai/nn/_ops/cache_embedding/__init__.py b/colossalai/nn/_ops/cache_embedding/__init__.py
index 0510e89f6..10dbe1c8a 100644
--- a/colossalai/nn/_ops/cache_embedding/__init__.py
+++ b/colossalai/nn/_ops/cache_embedding/__init__.py
@@ -1,5 +1,6 @@
 from .cache_mgr import CachedParamMgr
 from .copyer import LimitBuffIndexCopyer
 from .freq_aware_embedding import FreqAwareEmbeddingBag
+from .parallel_freq_aware_embedding import ParallelFreqAwareEmbeddingBag
 
-__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag']
+__all__ = ['CachedParamMgr', 'LimitBuffIndexCopyer', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag']
diff --git a/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
new file mode 100644
index 000000000..4400d6fc2
--- /dev/null
+++ b/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
@@ -0,0 +1,136 @@
+import torch
+import torch.nn.functional as F
+from typing import List, Optional, Iterator, Tuple
+
+from .base_embedding import BaseEmbeddingBag
+from .cache_mgr import CachedParamMgr
+from torch.nn.parameter import Parameter
+from .._utils import dual_all_to_all
+
+from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup
+
+
+def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
+    if world_size == 1:
+        return 0, embedding_dim, True
+
+    assert embedding_dim >= world_size, \
+        f"Embedding dimension {embedding_dim} must be larger than the world size " \
+        f"{world_size} of the process group"
+    chunk_size = embedding_dim // world_size
+    threshold = embedding_dim % world_size
+    # if embedding dim is divisible by world size
+    if threshold == 0:
+        return rank * chunk_size, (rank + 1) * chunk_size, True
+
+    # align with the split strategy of torch.tensor_split
+    size_list = [chunk_size + 1 if i < threshold else chunk_size for i in range(world_size)]
+    offset = sum(size_list[:rank])
+    return offset, offset + size_list[rank], False
+
+
+class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
+
+    def __init__(self,
+                 num_embeddings,
+                 embedding_dim,
+                 padding_idx=None,
+                 max_norm=None,
+                 norm_type=2.,
+                 scale_grad_by_freq=False,
+                 sparse=False,
+                 _weight=None,
+                 mode='mean',
+                 include_last_offset=False,
+                 dtype=None,
+                 debug=True):
+        super(ParallelFreqAwareEmbeddingBag,
+              self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq,
+                             sparse, mode, include_last_offset)
+
+        self.rank = torch.distributed.get_rank()
+        self.world_size = torch.distributed.get_world_size()
+        self.debug = debug
+
+        self.partition_start_index, self.partition_end_index, divisible = get_partition(
+            embedding_dim, self.rank, self.world_size)
+        self.embedding_dim_per_partition = self.partition_end_index - self.partition_start_index
+
+        if _weight is None:
+            self._weight = ColoParameter.from_torch_tensor(torch.empty(self.num_embeddings,
+                                                                       self.embedding_dim_per_partition,
+                                                                       device='cpu',
+                                                                       dtype=dtype),
+                                                           requires_grad=True,
+                                                           spec=ShardSpec(dims=[-1], num_partitions=[self.world_size]))
+            self._weight.process_group = ProcessGroup(tp_degree=self.world_size)
+            self.init_parameters()
+        else:
+            assert isinstance(_weight, ColoParameter), "the initialized weight must be a ColoParameter"
+            _weight.process_group = ProcessGroup(tp_degree=self.world_size)
+            _weight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[self.world_size]),
+                                    ComputeSpec(ComputePattern.TP1D))
+            self._weight = _weight
+
+    @property
+    def weight(self):
+        return self.cache_weight_mgr.cpu_weight
+
+    def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
+        yield 'weight', self.cache_weight_mgr.cuda_cached_weight
+
+    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        yield self.cache_weight_mgr.cuda_cached_weight
+
+    @torch.no_grad()
+    def init_parameters(self):
+        self._weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)
+        if self.padding_idx is not None:
+            self._weight[self.padding_idx].fill_(0)
+
+    def preprocess(self,
+                   cuda_row_num: int,
+                   ids_freq_mapping: Optional[List[int]] = None,
+                   warmup_ratio: float = 0.7,
+                   buffer_size: int = 50_000):
+        self.cache_weight_mgr = CachedParamMgr(self._weight, cuda_row_num, buffer_size=buffer_size)
+        self.cache_weight_mgr.reorder(ids_freq_mapping, warmup_ratio)
+
+    def forward(self, indices, offsets=None, per_sample_weights=None, shape_hook=None, scatter_dim=0, gather_dim=-1):
+        with torch.no_grad():
+            reorder_ids = self.cache_weight_mgr.prepare_ids(indices)
+
+        output_shard = F.embedding_bag(reorder_ids, self.cache_weight_mgr.cuda_cached_weight, offsets, self.max_norm,
+                                       self.norm_type, self.scale_grad_by_freq, self.mode, self.sparse,
+                                       per_sample_weights, self.include_last_offset, self.padding_idx)
+
+        if shape_hook is not None:
+            output_shard = shape_hook(output_shard)
+
+        output_full = dual_all_to_all(output_shard,
+                                      self._weight.get_process_group(),
+                                      scatter_dim=scatter_dim,
+                                      gather_dim=gather_dim)
+        return output_full
+
+    @classmethod
+    def from_pretrained(cls,
+                        embedding: torch.Tensor,
+                        freeze: bool = True,
+                        padding_idx: Optional[int] = None,
+                        max_norm: Optional[float] = None,
+                        norm_type: float = 2.,
+                        scale_grad_by_freq: bool = False,
+                        sparse: bool = False,
+                        mode: str = 'mean',
+                        include_last_offset: bool = False,
+                        debug: bool = True,
+                        cuda_row_num: int = 100_000,
+                        ids_freq_mapping: Optional[List[int]] = None,
+                        warmup_ratio: float = 0.7) -> 'ParallelFreqAwareEmbeddingBag':
+        rows, cols = embedding.shape
+        embedding_bag = cls(rows, cols, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, embedding, mode,
+                            include_last_offset, debug)
+        embedding_bag.preprocess(cuda_row_num, ids_freq_mapping, warmup_ratio)
+        embedding_bag.cache_weight_mgr.cuda_cached_weight.requires_grad_(not freeze)
+        return embedding_bag
diff --git a/tests/test_tensor/ops/test_cache_embedding.py b/tests/test_tensor/ops/test_cache_embedding.py
index 688d59b91..ac5b3bc40 100644
--- a/tests/test_tensor/ops/test_cache_embedding.py
+++ b/tests/test_tensor/ops/test_cache_embedding.py
@@ -3,9 +3,13 @@ from functools import partial
 import torch
 import torch.multiprocessing as mp
 import numpy as np
+import random
 
+import colossalai
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.tensor import ColoParameter
+from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
 
 from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag
 
@@ -13,6 +17,15 @@ NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
 
 
+def set_seed(seed):
+    """
+    To achieve reproducible results, it's necessary to fix the random seeds.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+
 def synthesize_1d_sparse_feature(
     batch_size,
     num_embed,
@@ -128,7 +141,91 @@ def test_freq_aware_embed():
         f"model weight: {model_weight[10:18, :8]}, reference: {ref_weight[10:18, :8]}"
 
 
+def gather_tensor(tensor, rank, world_size):
+    gather_list = []
+    if rank == 0:
+        gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
+
+    torch.distributed.gather(tensor, gather_list, dst=0)
+    return gather_list
+
+
+def run_parallel_freq_aware_embed(rank, world_size):
+    device = torch.device('cuda', torch.cuda.current_device())
+
+    num_embed = 100
+    embed_dim = 16
+    batch_size = 4
+
+    set_seed(4321)
+    weight = torch.rand(num_embed, embed_dim)
+    coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
+
+    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
+                                                          include_last_offset=True,
+                                                          freeze=False,
+                                                          cuda_row_num=batch_size * 2)
+
+    assert model.cache_weight_mgr.cpu_weight.device.type == 'cpu'
+    assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
+    weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
+    assert torch.allclose(
+        weight_in_rank,
+        model.cache_weight_mgr.cpu_weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.cpu_weight}"
+
+    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+    if rank == 0:
+        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(),
+                                                          include_last_offset=True,
+                                                          freeze=False).to(device)
+        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)
+
+    set_seed(4321)
+    for i in range(5):
+        indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
+        res = model(indices, offsets)
+
+        grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
+        grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank]
+        res.backward(grad_in_rank)
+
+        optimizer.step()
+        optimizer.zero_grad()
+
+        res_list = gather_tensor(res.detach(), rank, world_size)
+
+        if rank == 0:
+            ref_res = ref_model(indices, offsets)
+            recover_res = torch.cat(res_list, dim=0)
+
+            assert torch.allclose(ref_res, recover_res)
+
+            ref_res.backward(grad)
+            ref_optimizer.step()
+            ref_optimizer.zero_grad()
+
+    model.cache_weight_mgr.flush()
+    weight_list = gather_tensor(model.cache_weight_mgr.cpu_weight.detach().cuda(), rank, world_size)
+    if rank == 0:
+        recover_weight = torch.cat(weight_list, dim=1)
+        assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_parallel_freq_aware_embed(rank, world_size)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_parallel_freq_aware_embed(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 if __name__ == '__main__':
     # test_freq_aware_embed()
     # test_chunkmgr_admit()
-    pass
+    test_parallel_freq_aware_embed(2)
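
Reviewer note: a minimal usage sketch of the new dual_all_to_all collective added in this PR, assuming a 2-GPU NCCL group has already been initialized via colossalai.launch (shapes are illustrative only, not taken from the tests above):

# Each rank computes the EmbeddingBag output for the FULL batch, but only for its
# slice of the embedding dimension. dual_all_to_all scatters the batch dimension
# (scatter_dim=0) and gathers the embedding dimension (gather_dim=-1), so every
# rank ends up with half of the batch at the full embedding width.
import torch
from colossalai.tensor import ProcessGroup
from colossalai.nn._ops._utils import dual_all_to_all

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())    # tp_degree == 2 here
output_shard = torch.rand(8, 8, device='cuda')                     # (batch=8, embed_dim // world_size = 8)
output_full = dual_all_to_all(output_shard, pg, scatter_dim=0, gather_dim=-1)
# output_full has shape (4, 16) on each rank; the backward pass runs the mirrored all_to_all.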