From 3af7e65deaa45e2c05d22acf44295af248e7c12b Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 8 Dec 2022 10:04:09 +0800
Subject: [PATCH] [autoparallel] complete gpt related module search (#2097)

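Merge the last-dimension sharding mapping with the batch-dimension mapping in
the linear handler so that each sharding spec is updated through a single
update_partition_dim call. Parameterize the linear handler tests over the
input shape: a (1, 4, 4, 16) input yields 19 strategies, since its size-1
first dimension cannot be sharded, versus 24 for a (4, 4, 4, 16) input.
Rename the GPT solver test to test_solver_with_gpt_related_module.py and add
a GPT2Model case covering the embeddings, dropout, stacked GPT2Block layers
and the final LayerNorm.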
---
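Notes:
    The key change in linear_handler.py is that the optional last-dimension
    mapping is no longer applied through its own update_partition_dim call;
    it is kept in a small dict and merged into the dim_mapping built for each
    strategy conversion, so every sharding spec is updated exactly once. A
    minimal sketch of the merge, with hypothetical dimension indices for a
    physical (4, 4, 4, 16) input viewed logically as (64, 16) (the real
    handler derives these from the op data):

        last_logical_input_dims = 1     # last dim of the logical (64, 16) view
        last_physical_input_dims = 3    # last dim of the physical (4, 4, 4, 16) input
        input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}

        # map the flattened logical batch dim (0) onto physical dim i; each
        # strategy copy is suffixed with _{i}, which is why the tests check
        # the _0/_1/_2 variants of the sharded strategies
        for i in range(3):
            input_dim_mapping = {0: i}
            input_dim_mapping.update(input_last_dim_mapping)
            # the merged dict is passed to a single update_partition_dim(...) call
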
 .../node_handler/linear_handler.py            |  37 ++++---
 .../test_node_handler/test_linear_handler.py  |  89 ++++++++++------
 ...=> test_solver_with_gpt_related_module.py} | 100 +++++++++++++++++-
 3 files changed, 173 insertions(+), 53 deletions(-)
 rename tests/test_auto_parallel/test_tensor_shard/{test_solver_with_gpt_block.py => test_solver_with_gpt_related_module.py} (71%)

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
index 2bb852dfa..659edf548 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
@@ -64,20 +64,14 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
     last_physical_output_dims = output_op_data.data.dim() - 1
 
     if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
-        update_partition_dim(
-            sharding_spec=input_sharding_spec,
-            dim_mapping={last_logical_input_dims: last_physical_input_dims},
-            physical_shape=input_op_data.data.shape,
-            inplace=True,
-        )
+        input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
+    else:
+        input_last_dim_mapping = {}
 
     if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
-        update_partition_dim(
-            sharding_spec=output_sharding_spec,
-            dim_mapping={last_logical_output_dims: last_physical_output_dims},
-            physical_shape=output_op_data.data.shape,
-            inplace=True,
-        )
+        output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
+    else:
+        output_last_dim_mapping = {}
 
     # get logger for debug message
     logger = get_dist_logger()
@@ -97,12 +91,18 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
             output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
             try:
                 # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
+                input_dim_mapping = {0: i}
+                input_dim_mapping.update(input_last_dim_mapping)
+
                 update_partition_dim(sharding_spec=input_sharding_spec,
-                                     dim_mapping={0: i},
+                                     dim_mapping=input_dim_mapping,
                                      physical_shape=input_op_data.data.shape,
                                      inplace=True)
+                output_dim_mapping = {0: i}
+                output_dim_mapping.update(output_last_dim_mapping)
+
                 update_partition_dim(sharding_spec=output_sharding_spec,
-                                     dim_mapping={0: i},
+                                     dim_mapping=output_dim_mapping,
                                      physical_shape=output_op_data.data.shape,
                                      inplace=True)
                 strategy_copy.name = f'{strategy.name}_{i}'
@@ -120,12 +120,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
         output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
 
         # after updating, the logical shape will be replaced by the physical shape
+        input_dim_mapping = {}
+        input_dim_mapping.update(input_last_dim_mapping)
         update_partition_dim(sharding_spec=input_sharding_spec,
-                             dim_mapping={},
+                             dim_mapping=input_dim_mapping,
                              physical_shape=input_op_data.data.shape,
                              inplace=True)
+
+        output_dim_mapping = {}
+        output_dim_mapping.update(output_last_dim_mapping)
         update_partition_dim(sharding_spec=output_sharding_spec,
-                             dim_mapping={},
+                             dim_mapping=output_dim_mapping,
                              physical_shape=output_op_data.data.shape,
                              inplace=True)
         sharding_strategies.append(strategy_copy)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
index e0130936d..fb8821fae 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
@@ -26,18 +26,21 @@ from colossalai.utils import free_port
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
 
 
-def check_linear_module_handler(rank, bias, world_size, port):
+def check_linear_module_handler(rank, bias, input_shape, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    input = torch.rand(4, 4, 4, 16).cuda()
+    input = torch.rand(input_shape).cuda()
     # the index of linear node in computation graph
     node_index = 1
     # strategy number of linear node
-    strategy_number = 24
+    if input_shape == (1, 4, 4, 16):
+        strategy_number = 19
+    else:
+        strategy_number = 24
     # construct input args
     input_args = [input]
     # construct meta arg names
@@ -50,7 +53,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
                                      meta_arg_names=meta_arg_names)
 
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 4, 16).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')})
     gm = ColoGraphModule(model, graph)
 
     linear_mod_node = list(graph.nodes)[1]
@@ -69,9 +72,10 @@ def check_linear_module_handler(rank, bias, world_size, port):
         assert op_data.data is not None
 
     assert mapping['input'].name == "input_1"
-    assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
+    assert mapping['input'].data.shape == torch.Size(input_shape)
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([64, 16])
+    input_logical_shape = mapping['input'].data.view(-1, 16).shape
+    assert mapping['input'].logical_shape == input_logical_shape
 
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -85,28 +89,32 @@ def check_linear_module_handler(rank, bias, world_size, port):
         assert mapping['bias'].logical_shape == torch.Size([32])
 
     assert mapping['output'].name == "_0"
-    assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
+    output_shape = input_shape[:-1] + (32,)
+    assert mapping['output'].data.shape == torch.Size(output_shape)
     assert mapping['output'].type == OperationDataType.OUTPUT
-    assert mapping['output'].logical_shape == torch.Size([64, 32])
+    output_logical_shape = mapping['output'].data.view(-1, 32).shape
+    assert mapping['output'].logical_shape == torch.Size(output_logical_shape)
 
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
-    # one strategy will be converted to different physical sharding spec
-    assert len(strategy_name_list) > 8
+
+    # The first dimension cannot be sharded if the input shape is (1, 4, 4, 16)
+    if input_shape != (1, 4, 4, 16):
+        assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+        assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+        assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+        assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+        assert 'S01R = S01R x RR_0' in strategy_name_list
 
     # SS = SR x RS
-    assert 'S0S1 = S0R x RS1_0' in strategy_name_list
     assert 'S0S1 = S0R x RS1_1' in strategy_name_list
     assert 'S0S1 = S0R x RS1_2' in strategy_name_list
-    assert 'S1S0 = S1R x RS0_0' in strategy_name_list
     assert 'S1S0 = S1R x RS0_1' in strategy_name_list
     assert 'S1S0 = S1R x RS0_2' in strategy_name_list
 
     # SR = SS x SR
-    assert 'S0R = S0S1 x S1R_0' in strategy_name_list
     assert 'S0R = S0S1 x S1R_1' in strategy_name_list
     assert 'S0R = S0S1 x S1R_2' in strategy_name_list
-    assert 'S1R = S1S0 x S0R_0' in strategy_name_list
     assert 'S1R = S1S0 x S0R_1' in strategy_name_list
     assert 'S1R = S1S0 x S0R_2' in strategy_name_list
 
@@ -123,7 +131,6 @@ def check_linear_module_handler(rank, bias, world_size, port):
     assert 'RS1 = RR x RS1' in strategy_name_list
 
     # S01R = S01R x RR
-    assert 'S01R = S01R x RR_0' in strategy_name_list
     assert 'S01R = S01R x RR_1' in strategy_name_list
     assert 'S01R = S01R x RR_2' in strategy_name_list
 
@@ -164,7 +171,7 @@ class LinearModel(nn.Module):
         return x
 
 
-def check_linear_function_handler(rank, bias, world_size, port):
+def check_linear_function_handler(rank, bias, input_shape, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = LinearModel().cuda()
@@ -172,12 +179,15 @@ def check_linear_function_handler(rank, bias, world_size, port):
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
 
-    input = torch.rand(4, 4, 4, 16).cuda()
+    input = torch.rand(input_shape).cuda()
     other = torch.rand(32, 16).cuda()
     # the index of linear node in computation graph
     node_index = 2
     # strategy number of linear node
-    strategy_number = 24
+    if input_shape == (1, 4, 4, 16):
+        strategy_number = 19
+    else:
+        strategy_number = 24
     # construct input args
     input_args = [input, other]
     # construct meta arg names
@@ -192,7 +202,7 @@ def check_linear_function_handler(rank, bias, world_size, port):
     tracer = ColoTracer()
     graph = tracer.trace(model,
                          meta_args={
-                             "input": torch.rand(4, 4, 4, 16).to('meta'),
+                             "input": torch.rand(input_shape).to('meta'),
                              'others': torch.rand(32, 16).to('meta')
                          })
     gm = ColoGraphModule(model, graph)
@@ -209,9 +219,10 @@ def check_linear_function_handler(rank, bias, world_size, port):
     mapping = handler.get_operation_data_mapping()
 
     assert mapping['input'].name == "input_1"
-    assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
+    assert mapping['input'].data.shape == torch.Size(input_shape)
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([64, 16])
+    input_logical_shape = mapping['input'].data.view(-1, 16).shape
+    assert mapping['input'].logical_shape == torch.Size(input_logical_shape)
 
     assert mapping['other'].name == "others"
     assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -225,27 +236,32 @@ def check_linear_function_handler(rank, bias, world_size, port):
         assert mapping['other'].logical_shape == torch.Size([16, 32])
 
     assert mapping['output'].name == "linear"
-    assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
+    output_shape = input_shape[:-1] + (32,)
+    assert mapping['output'].data.shape == torch.Size(output_shape)
     assert mapping['output'].type == OperationDataType.OUTPUT
+    output_logical_shape = mapping['output'].data.view(-1, 32).shape
+    assert mapping['output'].logical_shape == torch.Size(output_logical_shape)
 
     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
-    # one strategy will be converted to different physical sharding spec
-    assert len(strategy_name_list) > 8
+
+    # The first dimension cannot be sharded if the input shape is (1, 4, 4, 16)
+    if input_shape != (1, 4, 4, 16):
+        assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+        assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+        assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+        assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+        assert 'S01R = S01R x RR_0' in strategy_name_list
 
     # SS = SR x RS
-    assert 'S0S1 = S0R x RS1_0' in strategy_name_list
     assert 'S0S1 = S0R x RS1_1' in strategy_name_list
     assert 'S0S1 = S0R x RS1_2' in strategy_name_list
-    assert 'S1S0 = S1R x RS0_0' in strategy_name_list
     assert 'S1S0 = S1R x RS0_1' in strategy_name_list
     assert 'S1S0 = S1R x RS0_2' in strategy_name_list
 
     # SR = SS x SR
-    assert 'S0R = S0S1 x S1R_0' in strategy_name_list
     assert 'S0R = S0S1 x S1R_1' in strategy_name_list
     assert 'S0R = S0S1 x S1R_2' in strategy_name_list
-    assert 'S1R = S1S0 x S0R_0' in strategy_name_list
     assert 'S1R = S1S0 x S0R_1' in strategy_name_list
     assert 'S1R = S1S0 x S0R_2' in strategy_name_list
 
@@ -262,7 +278,6 @@ def check_linear_function_handler(rank, bias, world_size, port):
     assert 'RS1 = RR x RS1' in strategy_name_list
 
     # S01R = S01R x RR
-    assert 'S01R = S01R x RR_0' in strategy_name_list
     assert 'S01R = S01R x RR_1' in strategy_name_list
     assert 'S01R = S01R x RR_2' in strategy_name_list
 
@@ -293,15 +308,23 @@ def check_linear_function_handler(rank, bias, world_size, port):
             assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
 
 
-# @parameterize('bias', [True, False])
+@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_linear_handler(bias=False):
+def test_linear_handler(input_shape, bias=False):
     world_size = 4
-    run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
+    run_func_module = partial(check_linear_module_handler,
+                              bias=bias,
+                              input_shape=input_shape,
+                              world_size=world_size,
+                              port=free_port())
     mp.spawn(run_func_module, nprocs=world_size)
-    run_func_function = partial(check_linear_function_handler, bias=bias, world_size=world_size, port=free_port())
+    run_func_function = partial(check_linear_function_handler,
+                                bias=bias,
+                                input_shape=input_shape,
+                                world_size=world_size,
+                                port=free_port())
     mp.spawn(run_func_function, nprocs=world_size)
 
 
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_block.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py
similarity index 71%
rename from tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_block.py
rename to tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py
index f88d907c6..82accebdb 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_block.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py
@@ -1,11 +1,14 @@
 from typing import Optional, Tuple, Union
 
 import torch
-# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
 import torch.nn as nn
 import transformers
 from torch.fx import GraphModule
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+from transformers.models.gpt2.modeling_gpt2 import (
+    GPT2MLP,
+    BaseModelOutputWithPastAndCrossAttentions,
+    GPT2PreTrainedModel,
+)
 from transformers.pytorch_utils import Conv1D
 
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
@@ -173,8 +176,91 @@ class GPT2Block(nn.Module):
         return outputs    # hidden_states, present, (attentions, cross_attentions)
 
 
+class GPT2Model(GPT2PreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.embed_dim = config.hidden_size
+
+        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+
+        self.drop = nn.Dropout(config.embd_pdrop)
+        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        batch_size = input_ids.shape[0]
+
+        device = input_ids.device
+
+        token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+        past_length = 0
+        past_key_values = tuple([None] * len(self.h))
+
+        position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+        # GPT2Attention mask.
+        attention_mask = attention_mask.view(batch_size, -1)
+        attention_mask = attention_mask[:, None, None, :]
+        attention_mask = attention_mask.to(dtype=self.dtype)    # fp16 compatibility
+        attention_mask = (1.0 - attention_mask) * -10000.0
+
+        encoder_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        # add_2
+        hidden_states = inputs_embeds + position_embeds
+
+        token_type_embeds = self.wte(token_type_ids)
+        hidden_states = hidden_states + token_type_embeds
+
+        # transformer_drop
+        hidden_states = self.drop(hidden_states)
+        # comment out to run pipeline
+        # add_3
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        presents = None
+        all_self_attentions = None
+        all_cross_attentions = None
+        all_hidden_states = None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
+            hidden_states = outputs[0]
+
+        hidden_states = self.ln_f(hidden_states)
+        # comment out to run pipeline
+        hidden_states = hidden_states.view(output_shape)
+
+        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+                     if v is not None)
+
+
 @run_on_environment_flag(name='AUTO_PARALLEL')
-@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP])
+@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
 def test_self_attention_block(model_cls):
     config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
     if model_cls == GPT2MLP:
@@ -193,11 +279,17 @@ def test_self_attention_block(model_cls):
         input_sample = {
             'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
         }
-    else:
+    elif model_cls in (GPT2Attention, GPT2Block):
         input_sample = {
             'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
             'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
         }
+    else:
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        input_sample = {k: v.to('meta') for k, v in kwargs.items()}
 
     graph = tracer.trace(root=model, meta_args=input_sample)