From 22717a856feed71b41671bd2c38fe2f5ad00bdb5 Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 22 Jun 2022 15:54:03 +0800
Subject: [PATCH] [tensor] add embedding bag op (#1156)

---
 colossalai/nn/_ops/__init__.py             |   1 +
 colossalai/nn/_ops/embedding_bag.py        | 122 +++++++++++++++++++++
 tests/test_tensor/test_embedding_bag_tp.py |  56 ++++++++++
 3 files changed, 179 insertions(+)
 create mode 100644 colossalai/nn/_ops/embedding_bag.py
 create mode 100644 tests/test_tensor/test_embedding_bag_tp.py

diff --git a/colossalai/nn/_ops/__init__.py b/colossalai/nn/_ops/__init__.py
index e9ce2b1ff..784f0abc6 100644
--- a/colossalai/nn/_ops/__init__.py
+++ b/colossalai/nn/_ops/__init__.py
@@ -4,3 +4,4 @@ from .layernorm import colo_layernorm
 from .loss import colo_cross_entropy
 from .embedding import colo_embedding
 from .addmm import colo_addmm
+from .embedding_bag import colo_embedding_bag
diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py
new file mode 100644
index 000000000..bf9dcbdd1
--- /dev/null
+++ b/colossalai/nn/_ops/embedding_bag.py
@@ -0,0 +1,122 @@
+import torch.nn.functional as F
+from typing import Optional
+from torch import Tensor
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.tensor import ComputePattern, TensorSpec, ParallelAction, ColoTensor, distspec
+from ._utils import GeneralTensor, convert_to_colo_tensor
+
+
+def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
+                             weight: ColoTensor,
+                             offsets: Optional[Tensor] = None,
+                             max_norm: Optional[float] = None,
+                             norm_type: float = 2,
+                             scale_grad_by_freq: bool = False,
+                             mode: str = "mean",
+                             sparse: bool = False,
+                             per_sample_weights: Optional[Tensor] = None,
+                             include_last_offset: bool = False,
+                             padding_idx: Optional[int] = None) -> ColoTensor:
+    # embedding_bag_1Dcol splits the weight (lookup table) into (num_embeddings, embedding_dim/P) shards.
+    # Replicate the input so every rank can query its own shard of the lookup table.
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+
+    output_parallel = F.embedding_bag(input_tensor,
+                                      weight,
+                                      offsets=offsets,
+                                      max_norm=max_norm,
+                                      norm_type=norm_type,
+                                      scale_grad_by_freq=scale_grad_by_freq,
+                                      mode=mode,
+                                      sparse=sparse,
+                                      per_sample_weights=per_sample_weights,
+                                      include_last_offset=include_last_offset,
+                                      padding_idx=padding_idx)
+    output_spec = TensorSpec(
+        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
+        ParallelAction(ComputePattern.TP1D))
+    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
+    if weight.spec.parallel_action.gather_out:
+        output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    return output
+
+
+def colo_embedding_bag_1d(tp_mode: str,
+                          input_tensor: ColoTensor,
+                          weight: ColoTensor,
+                          offsets: Optional[Tensor] = None,
+                          max_norm: Optional[float] = None,
+                          norm_type: float = 2,
+                          scale_grad_by_freq: bool = False,
+                          mode: str = "mean",
+                          sparse: bool = False,
+                          per_sample_weights: Optional[Tensor] = None,
+                          include_last_offset: bool = False,
+                          padding_idx: Optional[int] = None) -> ColoTensor:
+    assert tp_mode in ('col',)
+    funcs = {'col': colo_embedding_bag_1Dcol}
+    return funcs[tp_mode](input_tensor,
+                          weight,
+                          offsets=offsets,
+                          max_norm=max_norm,
+                          norm_type=norm_type,
+                          scale_grad_by_freq=scale_grad_by_freq,
+                          mode=mode,
+                          sparse=sparse,
+                          per_sample_weights=per_sample_weights,
+                          include_last_offset=include_last_offset,
+                          padding_idx=padding_idx)
+
+
+@colo_op_impl(F.embedding_bag)
+def colo_embedding_bag(input_tensor: GeneralTensor,
+                       weight: GeneralTensor,
+                       offsets: Optional[Tensor] = None,
+                       max_norm: Optional[float] = None,
+                       norm_type: float = 2,
+                       scale_grad_by_freq: bool = False,
+                       mode: str = "mean",
+                       sparse: bool = False,
+                       per_sample_weights: Optional[Tensor] = None,
+                       include_last_offset: bool = False,
+                       padding_idx: Optional[int] = None):
+    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.
+    This method looks up an embedding table.
+    """
+    input_tensor, weight = tuple(map(convert_to_colo_tensor, (input_tensor, weight)))
+
+    # Handle different parallel actions.
+
+    if not weight.has_spec():    # No Model Parallel Applied
+        assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
+        return ColoTensor.from_torch_tensor(
+            F.embedding_bag(input_tensor,
+                            weight,
+                            offsets=offsets,
+                            max_norm=max_norm,
+                            norm_type=norm_type,
+                            scale_grad_by_freq=scale_grad_by_freq,
+                            mode=mode,
+                            sparse=sparse,
+                            per_sample_weights=per_sample_weights,
+                            include_last_offset=include_last_offset,
+                            padding_idx=padding_idx))
+    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if weight.spec.is_1D_col():
+            tp_mode = 'col'
+        else:
+            raise NotImplementedError
+        return colo_embedding_bag_1d(tp_mode,
+                                     input_tensor,
+                                     weight,
+                                     offsets=offsets,
+                                     max_norm=max_norm,
+                                     norm_type=norm_type,
+                                     scale_grad_by_freq=scale_grad_by_freq,
+                                     mode=mode,
+                                     sparse=sparse,
+                                     per_sample_weights=per_sample_weights,
+                                     include_last_offset=include_last_offset,
+                                     padding_idx=padding_idx)
+    else:
+        raise NotImplementedError
diff --git a/tests/test_tensor/test_embedding_bag_tp.py b/tests/test_tensor/test_embedding_bag_tp.py
new file mode 100644
index 000000000..61f7d137d
--- /dev/null
+++ b/tests/test_tensor/test_embedding_bag_tp.py
@@ -0,0 +1,56 @@
+import torch
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ColoTensor, distspec, ColoParameter
+from torch.nn import functional as F
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from _utils import tensor_equal, tensor_shard_equal
+
+
+def init_1d_col(weight):
+    spec = TensorSpec(
+        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        ParallelAction(ComputePattern.TP1D))
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+
+def run_with_spec(spec_init_func):
+    model = torch.nn.EmbeddingBag(10, 4).cuda()
+    weight = ColoParameter(model.weight.clone())
+    spec_init_func(weight)
+    inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
+    offsets = torch.tensor([0, 4]).cuda()
+    out = model(inputs, offsets=offsets)
+    colo_out = F.embedding_bag(inputs, weight, offsets=offsets)
+    assert tensor_equal(out, colo_out)
+    grad = torch.rand_like(out)
+    out.backward(grad)
+    colo_out.backward(grad)
+    assert tensor_shard_equal(model.weight.grad, weight.grad)
+
+
+def run_dist(rank, world_size, port):
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_with_spec(init_1d_col)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_embedding_bag_1d(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_embedding_bag_1d(4)
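
Usage note (not part of the patch): a minimal sketch of how the new op is exercised, mirroring
tests/test_tensor/test_embedding_bag_tp.py above. It assumes colossalai.launch has already set up
a 1D tensor-parallel group, as run_dist does in the test; only calls that appear in this patch are used.

    import torch
    import torch.nn.functional as F
    from colossalai.context.parallel_mode import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.tensor import (ColoParameter, ComputePattern, DistSpecManager,
                                   ParallelAction, TensorSpec, distspec)

    # Column-shard the lookup table across the 1D tensor-parallel group.
    model = torch.nn.EmbeddingBag(10, 4).cuda()
    weight = ColoParameter(model.weight.clone())
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                       [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_spec(spec)

    # F.embedding_bag now dispatches to colo_embedding_bag via __torch_function__.
    inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
    offsets = torch.tensor([0, 4]).cuda()
    out = F.embedding_bag(inputs, weight, offsets=offsets)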