diff --git a/colossalai/tensor/_ops/embedding.py b/colossalai/tensor/_ops/embedding.py index 84c95492f..c5497431e 100644 --- a/colossalai/tensor/_ops/embedding.py +++ b/colossalai/tensor/_ops/embedding.py @@ -9,7 +9,7 @@ from packaging import version from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor: - # embedding_1Dcol split the weight(lookup table) + # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) # Gather splitted lookup table parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Embedding) if not input_tensor.is_gathered(): @@ -25,6 +25,37 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwa output.gather() return output +def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor: + # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim) + # Find index in this shard and mask those not here + # Reduce all + parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Embedding) + if not input_tensor.is_gathered(): + input_tensor.gather() + + tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode) + num_embeddings_per_partition = weight.size(0) + vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition + vocab_end_index = vocab_start_index + num_embeddings_per_partition + + # Build the mask. + input_mask = (input_tensor.torch_tensor() < vocab_start_index) | \ + (input_tensor.torch_tensor() >= vocab_end_index) + # Mask the input. + # TODO(jzy) masked_input may be an activation managed by ColoTensor. + masked_input = input_tensor.torch_tensor().clone() - vocab_start_index + masked_input[input_mask] = 0 + + partial_output = torch.nn.functional.embedding(masked_input, weight.torch_tensor(), + *args, **kwargs) + + # Mask the output embedding. + partial_output[input_mask, :] = 0. + # Reduce across all the model parallel GPUs. + output = reduce_input(partial_output, parallel_action.parallel_mode) + output = ColoTensor.init_from_torch_tensor(output) + return output + @colo_op_impl(torch.nn.functional.embedding) def colo_embedding(types, args, kwargs, pg): """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``. @@ -48,7 +79,9 @@ def colo_embedding(types, args, kwargs, pg): return ColoTensor.init_from_torch_tensor(output) elif weight.shard_spec.num_action == 1: # Single Model Parallel Applied compute_patterns = weight.shard_spec.compute_patterns - if ComputePattern.TP1DCol_Embedding in compute_patterns: + if ComputePattern.TP1DRow_Embedding in compute_patterns: + return colo_embedding_1Drow(input_tensor, weight, args, kwargs) + elif ComputePattern.TP1DCol_Embedding in compute_patterns: return colo_embedding_1Dcol(input_tensor, weight, args, kwargs) else: raise NotImplementedError diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index 06a751a77..5416e1662 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -166,6 +166,7 @@ class ColoTensor(object): dim = -1 self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim) self._shard_pattern = ShardPattern.NA + self._size = self._torch_tensor.size() def is_gathered(self) -> bool: return self._shard_pattern == ShardPattern.NA diff --git a/tests/components_to_test/simple_net.py b/tests/components_to_test/simple_net.py index 58c9835d8..8c7ba2863 100644 --- a/tests/components_to_test/simple_net.py +++ b/tests/components_to_test/simple_net.py @@ -5,7 +5,6 @@ from .utils.dummy_data_generator import DummyDataGenerator from .registry import non_distributed_component_funcs from colossalai.utils.cuda import get_current_device - class SimpleNet(CheckpointModule): """ In this no-leaf module, it has subordinate nn.modules and a nn.Parameter. @@ -13,12 +12,14 @@ class SimpleNet(CheckpointModule): def __init__(self, checkpoint=False) -> None: super().__init__(checkpoint=checkpoint) + self.embed = nn.Embedding(20, 4) self.proj1 = nn.Linear(4, 8) self.ln1 = nn.LayerNorm(8) self.proj2 = nn.Linear(8, 4) self.ln2 = nn.LayerNorm(4) def forward(self, x): + x = self.embed(x) x = self.proj1(x) x = self.ln1(x) x = self.proj2(x) @@ -26,11 +27,12 @@ class SimpleNet(CheckpointModule): return x + class DummyDataLoader(DummyDataGenerator): def generate(self): - data = torch.rand(16, 4, device=get_current_device()) - label = torch.randint(low=0, high=2, size=(16,), device=get_current_device()) + data = torch.randint(low=0, high=20, size=(16,20), device=get_current_device()) + label = torch.randint(low=0, high=2, size=(16,4), device=get_current_device()) return data, label diff --git a/tests/test_tensor/test_embedding_tp.py b/tests/test_tensor/test_embedding_tp.py index 3b145ca1a..d1ea5cb3c 100644 --- a/tests/test_tensor/test_embedding_tp.py +++ b/tests/test_tensor/test_embedding_tp.py @@ -65,10 +65,60 @@ def run_embedding_tp1d_col_test(): W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[local_rank] check_equal(W_grad, layer.weight.grad) +def run_embedding_tp1d_row_test(): + device = get_current_device() + dtype = torch.float32 + DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D) + num_embeddings = 12 + embedding_dim = 32 + + local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer_master = torch.nn.Embedding(num_embeddings, embedding_dim) + layer = torch.nn.Embedding(num_embeddings, embedding_dim) + + A_master = torch.tensor((0,3,6,9), device=device) + A = broadcast_tensor_chunk(A_master, chunk_size=1) + + W_shape = (num_embeddings, embedding_dim) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + W = broadcast_tensor_chunk(W_master, chunk_size=1) + W.requires_grad = True + + # replace the torch nn.Parameters with ColoTensor + sharded_weight = ColoTensor.init_from_torch_tensor(W) + parallel_action_list = [ + ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, + parallel_mode=ParallelMode.PARALLEL_1D) + ] + spec = TensorSpec(parallel_action_list) + sharded_weight.set_spec(spec) # reshard + replace_parameter_add_grad(layer, sharded_weight) + out = layer(A) + + replace_parameter_add_grad(layer_master, W_master) + C_master = layer_master(A_master) + C = C_master.clone() + + check_equal(out, C) + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + grad = broadcast_tensor_chunk(grad_master, chunk_size=1) + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[local_rank] + check_equal(W_grad, layer.weight.grad) + def run_dist(rank, world_size, port): config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_embedding_tp1d_col_test() + run_embedding_tp1d_row_test() @pytest.mark.dist @parameterize('world_size', [1, 4]) diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index 7610c5d8d..b56892e6d 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -47,6 +47,11 @@ def run_1d_col_tp(): ] spec_col = TensorSpec(parallel_action_list_col) + parallel_action_list_embedding_col = [ + ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D) + ] + spec_embedding_col = TensorSpec(parallel_action_list_embedding_col) + set_seed(1) if rank == 0: model_torch = model_builder(checkpoint=True) @@ -60,6 +65,8 @@ def run_1d_col_tp(): p.set_spec(spec_col) if 'proj2' in name and 'weight' in name: p.set_spec(spec_row) + if 'embed' in name and 'weight' in name: + p.set_spec(spec_embedding_col) model = model.cuda() @@ -172,6 +179,11 @@ def run_1d_row_tp(): ] spec = TensorSpec(parallel_action_list) + parallel_action_list_embedding_row = [ + ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D) + ] + spec_embedding_row = TensorSpec(parallel_action_list_embedding_row) + set_seed(1) if rank == 0: model_torch = model_builder(checkpoint=True) @@ -183,6 +195,8 @@ def run_1d_row_tp(): continue if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name: p.set_spec(spec) + if 'embed' in name and 'weight' in name: + p.set_spec(spec_embedding_row) model = model.cuda() @@ -227,7 +241,7 @@ def run_dist(rank, world_size, port): config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') run_1d_row_tp() - + run_1d_col_tp() @pytest.mark.dist @parameterize('world_size', [1, 4]) @@ -238,6 +252,6 @@ def test_simple_net(world_size): if __name__ == '__main__': - # test_simple_net() + test_simple_net() # test_model_parameters() - test_colo_optimizer() + # test_colo_optimizer()