diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md
index 8fdf6be3b..7e6acb3d3 100644
--- a/examples/language/gpt/README.md
+++ b/examples/language/gpt/README.md
@@ -39,9 +39,15 @@ If you want to test ZeRO1 and ZeRO2 in Colossal-AI, you need to ensure Colossal-
 For simplicity, the input data is randomly generated here.
 
 ## Training
-We provide two solutions. One utilizes the hybrid parallel strategies of Gemini, DDP/ZeRO, and Tensor Parallelism.
-The other one uses Pipeline Parallelism Only.
-In the future, we are going merge them together and they can be used orthogonally to each other.
+We provide two stable solutions.
+One utilizes Gemini to implement the hybrid parallel strategies of Gemini, DDP/ZeRO, and Tensor Parallelism for a Hugging Face GPT model.
+The other uses [Titans](https://github.com/hpcaitech/Titans), a model zoo built for distributed execution and maintained by Colossal-AI, to implement the hybrid parallel strategies of TP + ZeRO + PP.
+
+We recommend using Gemini to quickly run your model in a distributed manner.
+It does not require significant changes to the model structure, so you can easily apply it to a new model.
+Use Titans as a more advanced option when you want to pursue more extreme performance.
+Titans already includes some typical models, such as ViT and GPT.
+However, it requires some effort to get started with a new model structure.
 
 ### GeminiDPP/ZeRO + Tensor Parallelism
 ```bash
@@ -56,6 +62,11 @@ The `train_gpt_demo.py` provides three distributed plans, you can choose the pla
 - Pytorch DDP
 - Pytorch ZeRO
 
+### Titans (Tensor Parallelism) + ZeRO + Pipeline Parallelism
+
+Titans provides a customized GPT model, which uses distributed operators as building blocks.
+In [./titans/README.md](./titans/README.md), we provide a hybrid parallelism of ZeRO, TP, and PP.
+You can switch parallel strategies using a config file; see the excerpt below.
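+
+For a quick look, this is the relevant excerpt from one of the config files shipped with this example (`titans/configs/gpt2_small_zero3_pp1d.py`); adjust the sizes to match your own GPU count:
+
+```python
+# pipeline parallel: set the number of pipeline stages
+# tensor parallel: set the tensor parallel size, usually the number of GPUs per node
+parallel = dict(
+    pipeline=1,
+    tensor=dict(size=2, mode='1d'),    # for the current model implementation, mode can only be '1d' or None
+)
+```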
 
 ## Performance
 
diff --git a/examples/language/gpt/titans/LICENSE b/examples/language/gpt/titans/LICENSE
new file mode 100644
index 000000000..261eeb9e9
--- /dev/null
+++ b/examples/language/gpt/titans/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/examples/language/gpt/titans/README.md b/examples/language/gpt/titans/README.md
new file mode 100644
index 000000000..fe1854c9f
--- /dev/null
+++ b/examples/language/gpt/titans/README.md
@@ -0,0 +1,48 @@
+# Run GPT With Colossal-AI
+
+## How to Prepare Webtext Dataset
+
+You can download the preprocessed sample dataset for this demo via our [Google Drive sharing link](https://drive.google.com/file/d/1QKI6k-e2gJ7XgS8yIpgPPiMmwiBP_BPE/view?usp=sharing).
+
+
+You can also skip dataset preparation by passing `--use_dummy_dataset` when running the demo.
+
+## Run this Demo
+
+Use the following commands to install prerequisites.
+
+```bash
+# assuming CUDA 11.3
+pip install -r requirements.txt
+```
+
+Use the following commands to execute training.
+
+```bash
+#!/usr/bin/env sh
+# if you want to use the real dataset, remove --use_dummy_dataset
+# and set the DATA environment variable, e.g.
+# export DATA=/path/to/small-gpt-dataset.json
+
+# run on a single node
+colossalai run --nproc_per_node=<num_gpus> train_gpt.py --config configs/<config_file> --from_torch --use_dummy_dataset
+
+# run on multiple nodes with the colossalai launcher
+colossalai run --nproc_per_node=<num_gpus> \
+   --master_addr <hostname> \
+   --master_port <port-number> \
+   --hosts <list-of-hostname-separated-by-comma> \
+   train_gpt.py \
+   --config configs/<config_file> \
+   --from_torch \
+   --use_dummy_dataset
+
+# run on multiple nodes with slurm
+srun python \
+   train_gpt.py \
+   --config configs/<config_file> \
+   --host <master_node> \
+   --use_dummy_dataset
+
+```
+
+You can set `<config_file>` to any file in the `configs` folder. To get started quickly, try `gpt2_small_zero3_pp1d.py` on a single node first. Each config file contains comments explaining how to change the parallel settings; a minimal sketch follows below.
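+
+For example, here is a minimal sketch of the parallel-related settings, assuming you run on 4 GPUs (the pipeline size times the tensor parallel size must divide the total number of GPUs); the field names follow `configs/gpt2_small_zero3_pp1d.py`:
+
+```python
+from colossalai.zero.shard_utils import TensorShardStrategy
+
+# 2 pipeline stages x 2-way 1D tensor parallelism on 4 GPUs
+parallel = dict(
+    pipeline=2,
+    tensor=dict(size=2, mode='1d'),    # for the current model implementation, mode can only be '1d' or None
+)
+
+# comment out this dictionary entirely if you do not want ZeRO
+zero = dict(
+    model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
+    optimizer_config=dict(initial_scale=2**16),
+)
+```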
diff --git a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py
new file mode 100644
index 000000000..8ef81cb0a
--- /dev/null
+++ b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py
@@ -0,0 +1,31 @@
+from model import GPT2_small_pipeline_hybrid
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.zero.shard_utils import TensorShardStrategy
+
+BATCH_SIZE = 8
+NUM_EPOCHS = 10
+SEQ_LEN = 1024
+NUM_MICRO_BATCHES = 4
+HIDDEN_SIZE = 768
+TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
+
+# if you do not want ZeRO, just comment out this dictionary
+zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
+            optimizer_config=dict(initial_scale=2**16))
+
+optimizer = dict(
+    type=HybridAdam,
+    lr=0.00015,
+    weight_decay=1e-2,
+)
+
+model = dict(type=GPT2_small_pipeline_hybrid, checkpoint=True, num_chunks=1)
+
+# pipeline parallel: modify integer value for the number of pipeline stages
+# tensor parallel: modify size to set the tensor parallel size, usually the number of GPUs per node
+# for the current model implementation, mode can only be 1D or None
+parallel = dict(
+    pipeline=1,
+    tensor=dict(size=2, mode='1d'),
+)
diff --git a/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py
new file mode 100644
index 000000000..9f9816b30
--- /dev/null
+++ b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py
@@ -0,0 +1,31 @@
+from model import GPT3_pipeline_hybrid
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.zero.shard_utils import TensorShardStrategy
+
+BATCH_SIZE = 192
+NUM_EPOCHS = 60
+SEQ_LEN = 2048
+NUM_MICRO_BATCHES = 192
+HIDDEN_SIZE = 12288
+TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
+
+# if you do not want ZeRO, just comment out this dictionary
+zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
+            optimizer_config=dict(initial_scale=2**16))
+
+optimizer = dict(
+    type=HybridAdam,
+    lr=0.00015,
+    weight_decay=1e-2,
+)
+
+model = dict(type=GPT3_pipeline_hybrid, checkpoint=True, num_chunks=1)
+
+# pipeline parallel: modify integer value for the number of pipeline stages
+# tensor parallel: modify size to set the tensor parallel size, usually the number of GPUs per node
+# for the current model implementation, mode can only be 1D or None
+parallel = dict(
+    pipeline=1,
+    tensor=dict(size=2, mode='1d'),
+)
diff --git a/examples/language/gpt/titans/model/__init__.py b/examples/language/gpt/titans/model/__init__.py
new file mode 100644
index 000000000..eec48ef89
--- /dev/null
+++ b/examples/language/gpt/titans/model/__init__.py
@@ -0,0 +1,3 @@
+from .embed import vocab_parallel_cross_entropy
+from .gpt1d import *
+from .pipeline_gpt1d import *
diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py
new file mode 100644
index 000000000..6369b9f8c
--- /dev/null
+++ b/examples/language/gpt/titans/model/embed.py
@@ -0,0 +1,599 @@
+import torch
+import torch.nn.init as init
+from torch import Tensor
+from torch import distributed as dist
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn.parameter import Parameter
+
+from colossalai.context import ParallelMode, seed
+from colossalai.core import global_context as gpc
+from colossalai.nn.layer.base_layer import ParallelLayer
+from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
+from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row
+from colossalai.nn.layer.utils import divide
+from colossalai.registry import LAYERS, LOSSES, MODELS
+from colossalai.utils import get_current_device
+
+
+class VocabParallelEmbedding(torch.nn.Module):
+    """Language model embeddings.
+
+    Arguments:
+        hidden_size: hidden size
+        vocab_size: vocabulary size
+        max_sequence_length: maximum size of sequence. This
+                             is used for positional embedding
+        embedding_dropout_prob: dropout probability for embeddings
+        init_method: weight initialization method
+        num_tokentypes: size of the token-type embeddings. 0 value
+                        will ignore this embedding
+    """
+
+    def __init__(self,
+                 hidden_size,
+                 vocab_size,
+                 max_sequence_length,
+                 embedding_dropout_prob,
+                 num_tokentypes=0,
+                 dtype=torch.float):
+        super(VocabParallelEmbedding, self).__init__()
+
+        self.hidden_size = hidden_size
+        self.num_tokentypes = num_tokentypes
+
+        # Word embeddings (parallel).
+        self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype)
+        self._word_embeddings_key = 'word_embeddings'
+
+        # Position embedding (serial).
+        self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype)
+        self._position_embeddings_key = 'position_embeddings'
+        # Initialize the position embeddings.
+        # self.init_method(self.position_embeddings.weight)
+
+        # Token type embedding.
+        # Add this as an optional field that can be added through
+        # method call so we can load a pretrain model without
+        # token types and add them as needed.
+        self._tokentype_embeddings_key = 'tokentype_embeddings'
+        if self.num_tokentypes > 0:
+            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype)
+            # Initialize the token-type embeddings.
+            # self.init_method(self.tokentype_embeddings.weight)
+        else:
+            self.tokentype_embeddings = None
+
+        # Embeddings dropout
+        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+
+    def zero_parameters(self):
+        """Zero out all parameters in embedding."""
+        self.word_embeddings.weight.data.fill_(0)
+        self.word_embeddings.weight.shared = True
+        self.position_embeddings.weight.data.fill_(0)
+        self.position_embeddings.weight.shared = True
+        if self.num_tokentypes > 0:
+            self.tokentype_embeddings.weight.data.fill_(0)
+            self.tokentype_embeddings.weight.shared = True
+
+    def add_tokentype_embeddings(self, num_tokentypes):
+        """Add token-type embedding. This function is provided so we can add
+        token-type embeddings in case the pretrained model does not have it.
+        This allows us to load the model normally and then add this embedding.
+        """
+        if self.tokentype_embeddings is not None:
+            raise Exception('tokentype embeddings is already initialized')
+        if torch.distributed.get_rank() == 0:
+            print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
+        self.num_tokentypes = num_tokentypes
+        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
+        # Initialize the token-type embeddings.
+        # self.init_method(self.tokentype_embeddings.weight)
+
+    def forward(self, input_ids, position_ids=None, tokentype_ids=None):
+        # Embeddings.
+        if input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        words_embeddings = self.word_embeddings(input_ids)
+
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+        if position_ids is None:
+            position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+        position_embeddings = self.position_embeddings(position_ids)
+
+        embeddings = words_embeddings + position_embeddings
+
+        # Dropout.
+        with seed(ParallelMode.TENSOR):
+            embeddings = self.embedding_dropout(embeddings)
+        return embeddings
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+        """For easy load."""
+
+        state_dict_ = {}
+        state_dict_[self._word_embeddings_key] \
+            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+        state_dict_[self._position_embeddings_key] \
+            = self.position_embeddings.state_dict(
+                destination, prefix, keep_vars)
+        if self.num_tokentypes > 0:
+            state_dict_[self._tokentype_embeddings_key] \
+                = self.tokentype_embeddings.state_dict(
+                    destination, prefix, keep_vars)
+
+        return state_dict_
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Customized load."""
+
+        # Word embedding.
+        if self._word_embeddings_key in state_dict:
+            state_dict_ = state_dict[self._word_embeddings_key]
+        else:
+            # for backward compatibility.
+            state_dict_ = {}
+            for key in state_dict.keys():
+                if 'word_embeddings' in key:
+                    state_dict_[key.split('word_embeddings.')[1]] \
+                        = state_dict[key]
+        self.word_embeddings.load_state_dict(state_dict_, strict=strict)
+
+        # Position embedding.
+        if self._position_embeddings_key in state_dict:
+            state_dict_ = state_dict[self._position_embeddings_key]
+        else:
+            # for backward compatibility.
+            state_dict_ = {}
+            for key in state_dict.keys():
+                if 'position_embeddings' in key:
+                    state_dict_[key.split('position_embeddings.')[1]] \
+                        = state_dict[key]
+        self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+
+        # Tokentype embedding.
+        if self.num_tokentypes > 0:
+            state_dict_ = {}
+            if self._tokentype_embeddings_key in state_dict:
+                state_dict_ = state_dict[self._tokentype_embeddings_key]
+            else:
+                # for backward compatibility.
+                for key in state_dict.keys():
+                    if 'tokentype_embeddings' in key:
+                        state_dict_[key.split('tokentype_embeddings.')[1]] \
+                            = state_dict[key]
+            if len(state_dict_.keys()) > 0:
+                self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
+            else:
+                print('***WARNING*** expected tokentype embeddings in the '
+                      'checkpoint but could not find it',
+                      flush=True)
+
+
+class VocabParallelEmbedding1D(torch.nn.Module):
+    """Embedding parallelized in the vocabulary dimension.
+
+    This is mainly adapted from torch.nn.Embedding and all the default
+    values are kept.
+    Arguments:
+        num_embeddings: vocabulary size.
+        embedding_dim: size of hidden state.
+        init_method: method to initialize weights.
+    """
+
+    def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None):
+        super(VocabParallelEmbedding1D, self).__init__()
+        # Keep the input dimensions.
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        # Set the details for compatibility.
+        self.padding_idx = None
+        self.max_norm = None
+        self.norm_type = 2.
+        self.scale_grad_by_freq = False
+        self.sparse = False
+        self._weight = None
+        self.tensor_model_parallel_size = gpc.tensor_parallel_size
+        # Divide the weight matrix along the vocabulary dimension.
+        self.vocab_start_index, self.vocab_end_index = \
+            VocabUtility.vocab_range_from_global_vocab_size(
+                self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D),
+                self.tensor_model_parallel_size)
+        self.num_embeddings_per_partition = self.vocab_end_index - \
+            self.vocab_start_index
+
+        # Allocate weights and initialize.
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
+        init.uniform_(self.weight, -1, 1)
+
+    def forward(self, input_):
+        if self.tensor_model_parallel_size > 1:
+            # Build the mask.
+            input_mask = (input_ < self.vocab_start_index) | \
+                         (input_ >= self.vocab_end_index)
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
+        # Get the embeddings.
+        output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type,
+                                      self.scale_grad_by_freq, self.sparse)
+        # Mask the output embedding.
+        if self.tensor_model_parallel_size > 1:
+            output_parallel[input_mask, :] = 0.0
+        # Reduce across all the model parallel GPUs.
+        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
+        return output
+
+
+@LOSSES.register_module
+class vocab_parallel_cross_entropy(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, vocab_parallel_logits, target):
+        """Helper function for the cross entropy."""
+        vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous()
+        target = target[..., 1:].contiguous()
+        return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)),
+                                                target.view(-1))
+
+
+class _VocabParallelCrossEntropy(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, vocab_parallel_logits, target):
+
+        # Maximum value along vocab dimension across all GPUs.
+        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+        torch.distributed.all_reduce(logits_max,
+                                     op=torch.distributed.ReduceOp.MAX,
+                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        # Subtract the maximum value.
+        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+
+        # Get the partition's vocab indices
+        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
+        partition_vocab_size = vocab_parallel_logits.size()[-1]
+        rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+        world_size = gpc.tensor_parallel_size
+        vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
+
+        # Create a mask of valid vocab ids (1 means it needs to be masked).
+        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+        masked_target = target.clone() - vocab_start_index
+        masked_target[target_mask] = 0
+
+        # Get predicted-logits = logits[target].
+        # For Simplicity, we convert logits to a 2-D tensor with size
+        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+        masked_target_1d = masked_target.view(-1)
+        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+        predicted_logits = predicted_logits_1d.view_as(target)
+        predicted_logits[target_mask] = 0.0
+        # All reduce is needed to get the chunks from other GPUs.
+        torch.distributed.all_reduce(predicted_logits,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
+        # Sum of exponential of logits along vocab dimension across all GPUs.
+        exp_logits = vocab_parallel_logits
+        torch.exp(vocab_parallel_logits, out=exp_logits)
+        sum_exp_logits = exp_logits.sum(dim=-1)
+        torch.distributed.all_reduce(sum_exp_logits,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
+        # Loss = log(sum(exp(logits))) - predicted-logit.
+        loss = torch.log(sum_exp_logits) - predicted_logits
+        loss = loss.mean()
+        # Store softmax, target-mask and masked-target for backward pass.
+        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+        return loss
+
+    @staticmethod
+    def backward(ctx, grad_output):
+
+        # Retrieve tensors from the forward path.
+        softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+        # All the inputs have softmax as their gradient.
+        grad_input = softmax
+        # For simplicity, work with the 2D gradient.
+        partition_vocab_size = softmax.size()[-1]
+        grad_2d = grad_input.view(-1, partition_vocab_size)
+
+        # Add the gradient from matching classes.
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+        grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
+
+        # Finally elementwise multiplication with the output gradients.
+        grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+        return grad_input, None
+
+
+class VocabUtility:
+    """Split the vocabulary into `world_size` chunks amd return the
+        first and last index of the vocabulary belonging to the `rank`
+        partition: Note that indices in [fist, last)"""
+
+    @staticmethod
+    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
+        index_f = rank * per_partition_vocab_size
+        index_l = index_f + per_partition_vocab_size
+        return index_f, index_l
+
+    @staticmethod
+    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
+        per_partition_vocab_size = divide(global_vocab_size, world_size)
+        return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
+
+
+class VocabParallelGPTLMHead1D(ParallelLayer):
+    """
+    Language model head that shares the same parameters with the embedding matrix.
+    """
+
+    def __init__(self, embed=None, vocab_size=None, dtype=None, embed_dim=None):
+        super().__init__()
+        if embed is not None:
+            self.head = embed
+        else:
+            self.head = VocabParallelEmbedding1D(vocab_size, embed_dim, dtype=dtype)
+
+    def forward(self, x: Tensor) -> Tensor:
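+        # reduce_grad is an identity in the forward pass; it all-reduces the
+        # gradient across the 1D tensor-parallel group in the backward pass.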
+        x = reduce_grad(x, ParallelMode.PARALLEL_1D)
+        x = F.linear(x, self.head.weight)
+        return x
+
+
+###################################
+
+
+class HiddenParallelEmbedding(torch.nn.Module):
+    """Language model embeddings.
+
+    Arguments:
+        hidden_size: hidden size
+        vocab_size: vocabulary size
+        max_sequence_length: maximum size of sequence. This
+                             is used for positional embedding
+        embedding_dropout_prob: dropout probability for embeddings
+        init_method: weight initialization method
+        num_tokentypes: size of the token-type embeddings. 0 value
+                        will ignore this embedding
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        vocab_size,
+        max_sequence_length,
+        embedding_dropout_prob,
+        dtype=torch.float,
+        padding_idx: int = 0,
+        num_tokentypes=0,
+    ):
+        super(HiddenParallelEmbedding, self).__init__()
+
+        self.hidden_size = hidden_size
+        self.num_tokentypes = num_tokentypes
+
+        # Word embeddings (parallel).
+        self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
+        self._word_embeddings_key = 'word_embeddings'
+
+        # Position embedding (serial).
+        self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
+        self._position_embeddings_key = 'position_embeddings'
+        # Initialize the position embeddings.
+        # self.init_method(self.position_embeddings.weight)
+
+        # Token type embedding.
+        # Add this as an optional field that can be added through
+        # method call so we can load a pretrain model without
+        # token types and add them as needed.
+        self._tokentype_embeddings_key = 'tokentype_embeddings'
+        if self.num_tokentypes > 0:
+            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
+            # Initialize the token-type embeddings.
+            # self.init_method(self.tokentype_embeddings.weight)
+        else:
+            self.tokentype_embeddings = None
+
+        # Embeddings dropout
+        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+
+    def zero_parameters(self):
+        """Zero out all parameters in embedding."""
+        self.word_embeddings.weight.data.fill_(0)
+        self.word_embeddings.weight.shared = True
+        self.position_embeddings.weight.data.fill_(0)
+        self.position_embeddings.weight.shared = True
+        if self.num_tokentypes > 0:
+            self.tokentype_embeddings.weight.data.fill_(0)
+            self.tokentype_embeddings.weight.shared = True
+
+    def add_tokentype_embeddings(self, num_tokentypes):
+        """Add token-type embedding. This function is provided so we can add
+        token-type embeddings in case the pretrained model does not have it.
+        This allows us to load the model normally and then add this embedding.
+        """
+        if self.tokentype_embeddings is not None:
+            raise Exception('tokentype embeddings is already initialized')
+        if torch.distributed.get_rank() == 0:
+            print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
+        self.num_tokentypes = num_tokentypes
+        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
+        # Initialize the token-type embeddings.
+        # self.init_method(self.tokentype_embeddings.weight)
+
+    def forward(self, input_ids, position_ids=None, tokentype_ids=None):
+        if input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        words_embeddings = self.word_embeddings(input_ids)
+
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+        if position_ids is None:
+            position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+        position_embeddings = self.position_embeddings(position_ids)
+
+        embeddings = words_embeddings + position_embeddings
+
+        # Dropout.
+        with seed(ParallelMode.TENSOR):
+            embeddings = self.embedding_dropout(embeddings)
+        return embeddings
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+        """For easy load."""
+
+        state_dict_ = {}
+        state_dict_[self._word_embeddings_key] \
+            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+        state_dict_[self._position_embeddings_key] \
+            = self.position_embeddings.state_dict(
+                destination, prefix, keep_vars)
+        if self.num_tokentypes > 0:
+            state_dict_[self._tokentype_embeddings_key] \
+                = self.tokentype_embeddings.state_dict(
+                    destination, prefix, keep_vars)
+
+        return state_dict_
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Customized load."""
+
+        # Word embedding.
+        if self._word_embeddings_key in state_dict:
+            state_dict_ = state_dict[self._word_embeddings_key]
+        else:
+            # for backward compatibility.
+            state_dict_ = {}
+            for key in state_dict.keys():
+                if 'word_embeddings' in key:
+                    state_dict_[key.split('word_embeddings.')[1]] \
+                        = state_dict[key]
+        self.word_embeddings.load_state_dict(state_dict_, strict=strict)
+
+        # Position embedding.
+        if self._position_embeddings_key in state_dict:
+            state_dict_ = state_dict[self._position_embeddings_key]
+        else:
+            # for backward compatibility.
+            state_dict_ = {}
+            for key in state_dict.keys():
+                if 'position_embeddings' in key:
+                    state_dict_[key.split('position_embeddings.')[1]] \
+                        = state_dict[key]
+        self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+
+        # Tokentype embedding.
+        if self.num_tokentypes > 0:
+            state_dict_ = {}
+            if self._tokentype_embeddings_key in state_dict:
+                state_dict_ = state_dict[self._tokentype_embeddings_key]
+            else:
+                # for backward compatibility.
+                for key in state_dict.keys():
+                    if 'tokentype_embeddings' in key:
+                        state_dict_[key.split('tokentype_embeddings.')[1]] \
+                            = state_dict[key]
+            if len(state_dict_.keys()) > 0:
+                self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
+            else:
+                print('***WARNING*** expected tokentype embeddings in the '
+                      'checkpoint but could not find it',
+                      flush=True)
+
+
+class HiddenParallelEmbedding1D(torch.nn.Module):
+    """Embedding parallelized in the vocabulary dimension.
+
+    This is mainly adapted from torch.nn.Embedding and all the default
+    values are kept.
+    Arguments:
+        num_embeddings: vocabulary size.
+        embedding_dim: size of hidden state.
+        init_method: method to initialize weights.
+    """
+
+    def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx: int = None, init_method=None):
+        super(HiddenParallelEmbedding1D, self).__init__()
+        # Keep the input dimensions.
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
+        # Set the details for compatibility.
+        self.padding_idx = padding_idx
+        self.max_norm = None
+        self.norm_type = 2.
+        self.scale_grad_by_freq = False
+        self.sparse = False
+        self._weight = None
+
+        # Allocate weights and initialize.
+        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
+        init.uniform_(self.weight, -1, 1)
+
+    def forward(self, input_):
+
+        # Get the embeddings.
+        output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type,
+                                      self.scale_grad_by_freq, self.sparse)
+
+        # Reduce across all the model parallel GPUs.
+        output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+        return output
+
+
+@LAYERS.register_module
+class HiddenParallelGPTLMHead1D(ParallelLayer):
+    """
+    Language model head that shares the same parameters with the embedding matrix.
+    """
+
+    def __init__(
+        self,
+        embed=None,
+        embed_dim=None,
+        vocab_size=None,
+        dtype=None,
+    ):
+        super().__init__()
+        if embed is not None:
+            self.head = embed
+            self.synced_embed = True
+        else:
+            # self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
+            # (hidden_size/q, vocab_size)
+            self.synced_embed = False
+            self.head = Linear1D_Row(in_features=embed_dim,
+                                     out_features=vocab_size,
+                                     bias=False,
+                                     dtype=dtype,
+                                     parallel_input=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        if self.synced_embed:
+            x = F.linear(x, self.head.weight)
+        else:
+            x = self.head(x)
+
+        return x
diff --git a/examples/language/gpt/titans/model/gpt1d.py b/examples/language/gpt/titans/model/gpt1d.py
new file mode 100644
index 000000000..2edd03606
--- /dev/null
+++ b/examples/language/gpt/titans/model/gpt1d.py
@@ -0,0 +1,349 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import math
+
+import torch
+from torch import Tensor
+from torch import nn as nn
+
+from colossalai import kernel
+from colossalai import nn as col_nn
+from colossalai.core import global_context as gpc
+from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.nn.layer import Linear1D_Col, Linear1D_Row
+from colossalai.nn.layer.base_layer import ParallelLayer
+from colossalai.nn.layer.utils import ACT2FN, divide
+from colossalai.utils.activation_checkpoint import checkpoint
+
+__all__ = [
+    'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
+]
+
+
+class GPTMLP1D(ParallelLayer):
+
+    def __init__(
+        self,
+        in_features: int,
+        mlp_ratio: int,
+        act_func: str = 'gelu',
+        dropout_prob: float = 0.,
+        dtype=None,
+        checkpoint: bool = False,
+        skip_bias_add: bool = False,
+    ):
+        super().__init__()
+
+        self.in_features = in_features
+        self.mlp_ratio = mlp_ratio
+        self.checkpoint = checkpoint
+        self.skip_bias_add = skip_bias_add
+
+        self.act = ACT2FN[act_func]
+        skip_dense_1_add_bias = False
+
+        # Project to mlp_ratio * h.
+        self.dense_1 = Linear1D_Col(
+            self.in_features,
+            int(self.mlp_ratio * self.in_features),
+            dtype=dtype,
+            gather_output=False,
+            skip_bias_add=skip_dense_1_add_bias,
+        )
+
+        # Project back to h.
+        self.dense_2 = Linear1D_Row(
+            int(self.mlp_ratio * self.in_features),
+            self.in_features,
+            dtype=dtype,
+            parallel_input=True,
+        )
+
+        self.dropout = col_nn.Dropout(dropout_prob)
+
+    def _forward(self, hidden_states: Tensor) -> Tensor:
+        intermediate_output = self.dense_1(hidden_states)
+        intermediate_output = self.act(intermediate_output)
+
+        output = self.dense_2(intermediate_output)
+        output = self.dropout(output)
+        return output
+
+    def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
+        return checkpoint(self._forward, False, hidden_states)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        if self.checkpoint:
+            return self._checkpoint_forward(hidden_states)
+        else:
+            return self._forward(hidden_states)
+
+
+class GenericGPTSelfAttention1D(ParallelLayer):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        attention_dropout_prob: float,
+        hidden_dropout_prob: float,
+        dtype=None,
+        checkpoint: bool = False,
+        max_position_embeddings=1024,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.attention_head_size = divide(hidden_size, num_attention_heads)
+        self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
+        self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
+        self.checkpoint = checkpoint
+        self.query_key_value = Linear1D_Col(
+            hidden_size,
+            3 * hidden_size,
+            dtype=dtype,
+        )
+        self.attention_dropout = col_nn.Dropout(attention_dropout_prob)
+        self.dense = Linear1D_Row(
+            hidden_size,
+            hidden_size,
+            dtype=dtype,
+            parallel_input=True,
+        )
+        self.dropout = col_nn.Dropout(hidden_dropout_prob)
+
+    def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
+        raise NotImplementedError
+
+    def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+        query_key_value = self.query_key_value(hidden_states)
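+        # reshape (batch, seq, 3 * hidden / tp_size) -> (batch, heads_per_partition, seq, 3 * head_size),
+        # then split the last dimension into query, key and value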
+        new_qkv_shape = query_key_value.shape[:-1] + \
+            (self.num_attention_heads_per_partition, 3 * self.attention_head_size)
+        query_key_value = query_key_value.view(new_qkv_shape)
+        query_key_value = query_key_value.permute((0, 2, 1, 3))
+        query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1)
+
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+        attention_scores = self.softmax_forward(attention_scores, attention_mask, query_layer, key_layer)
+
+        attention_scores = attention_scores.type(value_layer.dtype)
+
+        attention_probs = self.attention_dropout(attention_scores)
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.transpose(1, 2)
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        context_layer = context_layer.reshape(new_context_layer_shape)
+        output = self.dense(context_layer)
+        output = self.dropout(output)
+
+        return output
+
+    def _checkpoint_forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+        return checkpoint(self._forward, False, hidden_states, attention_mask)
+
+    def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+        if self.checkpoint:
+            return self._checkpoint_forward(hidden_states, attention_mask)
+        else:
+            return self._forward(hidden_states, attention_mask)
+
+
+class GPTSelfAttention1D(GenericGPTSelfAttention1D):
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_attention_heads: int,
+                 attention_dropout_prob: float,
+                 hidden_dropout_prob: float,
+                 dtype=None,
+                 checkpoint: bool = False,
+                 max_position_embeddings=1024):
+        super().__init__(hidden_size,
+                         num_attention_heads,
+                         attention_dropout_prob,
+                         hidden_dropout_prob,
+                         dtype=dtype,
+                         checkpoint=checkpoint,
+                         max_position_embeddings=max_position_embeddings)
+        self.softmax = nn.Softmax(dim=-1)
+        max_positions = max_position_embeddings
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions),
+                                  dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e4))
+
+    def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        # causal mask
+        query_length, key_length = query_layer.size(-2), key_layer.size(-2)
+        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
+        attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores))
+        if attention_mask is not None:
+            # Apply the attention mask
+            attention_scores = attention_scores + attention_mask
+        attention_scores = self.softmax(attention_scores)
+        return attention_scores
+
+
+class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D):
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_attention_heads: int,
+                 attention_dropout_prob: float,
+                 hidden_dropout_prob: float,
+                 dtype=None,
+                 checkpoint: bool = False,
+                 max_position_embeddings=1024):
+        super().__init__(hidden_size,
+                         num_attention_heads,
+                         attention_dropout_prob,
+                         hidden_dropout_prob,
+                         dtype=dtype,
+                         checkpoint=checkpoint,
+                         max_position_embeddings=max_position_embeddings)
+        self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True,
+                                                    input_in_bf16=False,
+                                                    attn_mask_type=AttnMaskType.causal,
+                                                    scaled_masked_softmax_fusion=True,
+                                                    mask_func=None,
+                                                    softmax_in_fp32=True,
+                                                    scale=math.sqrt(self.attention_head_size))
+
+    def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
+        return self.softmax(attention_scores, attention_mask)
+
+
+class GenericGPTTransformerLayer1D(ParallelLayer):
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_attention_heads: int,
+                 act_func: str = 'gelu',
+                 mlp_ratio: float = 4.0,
+                 attention_dropout_prob: float = 0.,
+                 hidden_dropout_prob: float = 0.,
+                 dtype=None,
+                 checkpoint: bool = False,
+                 max_position_embeddings: int = 1024,
+                 layer_norm_epsilon: float = 1e-5,
+                 apply_post_layer_norm: bool = False,
+                 attention=None,
+                 layer_norm=None):
+        super().__init__()
+        self.checkpoint = checkpoint
+        self.dtype = dtype
+        self.norm1 = layer_norm(hidden_size, eps=layer_norm_epsilon)
+        self.apply_post_layer_norm = apply_post_layer_norm
+        self.attention = attention(
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            attention_dropout_prob=attention_dropout_prob,
+            hidden_dropout_prob=hidden_dropout_prob,
+            dtype=dtype,
+            max_position_embeddings=max_position_embeddings,
+            checkpoint=False,
+        )
+
+        self.norm2 = layer_norm(hidden_size, eps=layer_norm_epsilon)
+        self.mlp = GPTMLP1D(
+            in_features=hidden_size,
+            dropout_prob=hidden_dropout_prob,
+            act_func=act_func,
+            mlp_ratio=mlp_ratio,
+            dtype=dtype,
+            checkpoint=False,
+        )
+
+    def _forward(self, hidden_states, attention_mask) -> Tensor:
+        if not self.apply_post_layer_norm:
+            residual = hidden_states
+        hidden_states = self.norm1(hidden_states)
+        if self.apply_post_layer_norm:
+            residual = hidden_states
+        attention_output = self.attention(hidden_states, attention_mask)
+        hidden_states = residual + attention_output
+
+        if not self.apply_post_layer_norm:
+            residual = hidden_states
+        hidden_states = self.norm2(hidden_states)
+        if self.apply_post_layer_norm:
+            residual = hidden_states
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + feed_forward_hidden_states
+
+        output = (hidden_states, attention_mask)
+        return output
+
+    def forward(self, hidden_states, attention_mask):
+        if self.checkpoint:
+            return checkpoint(self._forward, False, hidden_states, attention_mask)
+        else:
+            return self._forward(hidden_states, attention_mask)
+
+
+class GPTTransformerLayer1D(GenericGPTTransformerLayer1D):
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_attention_heads: int,
+                 act_func: str = 'gelu',
+                 mlp_ratio: float = 4,
+                 attention_dropout_prob: float = 0,
+                 hidden_dropout_prob: float = 0,
+                 dtype=None,
+                 checkpoint: bool = False,
+                 max_position_embeddings: int = 1024,
+                 layer_norm_epsilon: float = 0.00001,
+                 apply_post_layer_norm: bool = False):
+        attention = GPTSelfAttention1D
+        layer_norm = nn.LayerNorm
+        super().__init__(hidden_size,
+                         num_attention_heads,
+                         act_func=act_func,
+                         mlp_ratio=mlp_ratio,
+                         attention_dropout_prob=attention_dropout_prob,
+                         hidden_dropout_prob=hidden_dropout_prob,
+                         dtype=dtype,
+                         checkpoint=checkpoint,
+                         max_position_embeddings=max_position_embeddings,
+                         layer_norm_epsilon=layer_norm_epsilon,
+                         apply_post_layer_norm=apply_post_layer_norm,
+                         attention=attention,
+                         layer_norm=layer_norm)
+
+
+class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D):
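+    # The same layer wired with ColossalAI's fused kernels: kernel.LayerNorm and FusedGPTSelfAttention1D.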
+
+    def __init__(self,
+                 hidden_size: int,
+                 num_attention_heads: int,
+                 act_func: str = 'gelu',
+                 mlp_ratio: float = 4.0,
+                 attention_dropout_prob: float = 0.,
+                 hidden_dropout_prob: float = 0.,
+                 dtype=None,
+                 checkpoint: bool = False,
+                 max_position_embeddings: int = 1024,
+                 layer_norm_epsilon: float = 1e-5,
+                 apply_post_layer_norm: bool = False):
+        attention = FusedGPTSelfAttention1D
+        layer_norm = kernel.LayerNorm
+        super().__init__(hidden_size,
+                         num_attention_heads,
+                         act_func=act_func,
+                         mlp_ratio=mlp_ratio,
+                         attention_dropout_prob=attention_dropout_prob,
+                         hidden_dropout_prob=hidden_dropout_prob,
+                         dtype=dtype,
+                         checkpoint=checkpoint,
+                         max_position_embeddings=max_position_embeddings,
+                         layer_norm_epsilon=layer_norm_epsilon,
+                         apply_post_layer_norm=apply_post_layer_norm,
+                         attention=attention,
+                         layer_norm=layer_norm)
diff --git a/examples/language/gpt/titans/model/pipeline_gpt1d.py b/examples/language/gpt/titans/model/pipeline_gpt1d.py
new file mode 100644
index 000000000..30180285b
--- /dev/null
+++ b/examples/language/gpt/titans/model/pipeline_gpt1d.py
@@ -0,0 +1,322 @@
+import inspect
+
+# import model_zoo.gpt.gpt as col_gpt
+import titans.model.gpt.gpt as col_gpt
+import torch
+import torch.nn as nn
+
+from colossalai import kernel
+from colossalai import nn as col_nn
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
+from colossalai.pipeline.utils import partition_uniform
+
+from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
+from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D
+
+__all__ = [
+    'GPT2_small_pipeline_1D',
+    'GPT2_exlarge_pipeline_1D',
+    'GPT3_pipeline_1D',
+    'GPT2_exlarge_pipeline_hybrid',
+    'GPT2_small_pipeline_hybrid',
+    'GPT3_pipeline_hybrid',
+]
+
+
+class GenericPipelineGPT(nn.Module):
+
+    def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None:
+        super().__init__()
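+        # A pipeline stage holds an optional embedding (first stage only), the transformer blocks assigned to it,
+        # and an optional final norm + LM head (last stage only).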
+        self.embedding = embedding
+        self.blocks = blocks
+        self.norm = norm
+        self.head = head
+        assert blocks is not None
+        if norm is not None or head is not None:
+            assert norm is not None and head is not None
+
+    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+        if self.embedding is not None:
+            hidden_states = self.embedding(input_ids=input_ids)
+        batch_size = hidden_states.shape[0]
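+        # Expand the padding mask to [batch, 1, 1, seq_len] and turn it into an additive mask
+        # (0 for visible tokens, -10000 for padded ones) that can be added to the attention scores.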
+        attention_mask = attention_mask.view(batch_size, -1)
+        attention_mask = attention_mask[:, None, None, :]
+        attention_mask = attention_mask.to(dtype=hidden_states.dtype)    # fp16 compatibility
+        attention_mask = (1.0 - attention_mask) * -10000.0
+        for block in self.blocks:
+            hidden_states, attention_mask = block(hidden_states, attention_mask)
+        if self.norm is not None:
+            hidden_states = self.head(self.norm(hidden_states))
+        return hidden_states
+
+
+class PipelineGPT1D(GenericPipelineGPT):
+
+    def __init__(self,
+                 num_layers: int = 12,
+                 hidden_size: int = 768,
+                 num_attention_heads: int = 12,
+                 vocab_size: int = 50304,
+                 embed_drop_rate: float = 0.,
+                 act_func: str = 'gelu',
+                 mlp_ratio: float = 4.0,
+                 attn_drop_rate: float = 0.,
+                 drop_rate: float = 0.,
+                 dtype: torch.dtype = torch.float,
+                 checkpoint: bool = False,
+                 max_position_embeddings: int = 1024,
+                 layer_norm_epsilon: float = 1e-5,
+                 apply_post_layer_norm: bool = False,
+                 first: bool = False,
+                 last: bool = False,
+                 embed_split_hidden=False):
+        embedding = None
+        norm = None
+        head = None
+        embed_cls = VocabParallelEmbedding
+        head_cls = VocabParallelGPTLMHead1D
+        if embed_split_hidden:
+            embed_cls = HiddenParallelEmbedding
+            head_cls = HiddenParallelGPTLMHead1D
+        if first:
+            embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
+        blocks = nn.ModuleList([
+            GPTTransformerLayer1D(hidden_size,
+                                  num_attention_heads,
+                                  act_func=act_func,
+                                  mlp_ratio=mlp_ratio,
+                                  attention_dropout_prob=attn_drop_rate,
+                                  hidden_dropout_prob=drop_rate,
+                                  dtype=dtype,
+                                  checkpoint=checkpoint,
+                                  max_position_embeddings=max_position_embeddings,
+                                  layer_norm_epsilon=layer_norm_epsilon,
+                                  apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
+        ])
+        if last:
+            norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+            head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
+        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
+
+
+class FusedPipelineGPT1D(GenericPipelineGPT):
+
+    def __init__(self,
+                 num_layers: int = 12,
+                 hidden_size: int = 768,
+                 num_attention_heads: int = 12,
+                 vocab_size: int = 50304,
+                 embed_drop_rate: float = 0.,
+                 act_func: str = 'gelu',
+                 mlp_ratio: float = 4.0,
+                 attn_drop_rate: float = 0.,
+                 drop_rate: float = 0.,
+                 dtype: torch.dtype = torch.float,
+                 checkpoint: bool = False,
+                 max_position_embeddings: int = 1024,
+                 layer_norm_epsilon: float = 1e-5,
+                 apply_post_layer_norm: bool = False,
+                 first: bool = False,
+                 last: bool = False,
+                 embed_split_hidden=False):
+        embedding = None
+        norm = None
+        head = None
+        embed_cls = VocabParallelEmbedding
+        head_cls = VocabParallelGPTLMHead1D
+        if embed_split_hidden:
+            embed_cls = HiddenParallelEmbedding
+            head_cls = HiddenParallelGPTLMHead1D
+        if first:
+            embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
+        blocks = nn.ModuleList([
+            FusedGPTTransformerLayer1D(hidden_size,
+                                       num_attention_heads,
+                                       act_func=act_func,
+                                       mlp_ratio=mlp_ratio,
+                                       attention_dropout_prob=attn_drop_rate,
+                                       hidden_dropout_prob=drop_rate,
+                                       dtype=dtype,
+                                       checkpoint=checkpoint,
+                                       max_position_embeddings=max_position_embeddings,
+                                       layer_norm_epsilon=layer_norm_epsilon,
+                                       apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
+        ])
+        if last:
+            norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+            head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
+        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
+
+    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+        if self.embedding is not None:
+            hidden_states = self.embedding(input_ids=input_ids)
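+        # Unlike the generic forward above, no additive mask is built here; the padding mask is only cast and handed to the fused attention blocks.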
+        attention_mask = attention_mask.to(dtype=hidden_states.dtype)    # fp16 compatibility
+        for block in self.blocks:
+            hidden_states, attention_mask = block(hidden_states, attention_mask)
+        if self.norm is not None:
+            hidden_states = self.head(self.norm(hidden_states))
+        return hidden_states
+
+
+class PipelineGPTHybrid(GenericPipelineGPT):
+
+    def __init__(self,
+                 num_layers: int = 12,
+                 hidden_size: int = 768,
+                 num_attention_heads: int = 12,
+                 vocab_size: int = 50304,
+                 embed_drop_rate: float = 0.,
+                 act_func: str = 'gelu',
+                 mlp_ratio: float = 4.0,
+                 attn_drop_rate: float = 0.,
+                 drop_rate: float = 0.,
+                 dtype: torch.dtype = torch.float,
+                 checkpoint: bool = False,
+                 max_position_embeddings: int = 1024,
+                 layer_norm_epsilon: float = 1e-5,
+                 apply_post_layer_norm: bool = False,
+                 first: bool = False,
+                 last: bool = False,
+                 embed_split_hidden=False):
+        embedding = None
+        norm = None
+        head = None
+        if first:
+            embedding = col_gpt.GPTEmbedding(hidden_size,
+                                             vocab_size,
+                                             max_position_embeddings,
+                                             dropout=embed_drop_rate,
+                                             dtype=dtype)
+        blocks = nn.ModuleList([
+            col_gpt.GPTBlock(hidden_size,
+                             num_attention_heads,
+                             mlp_ratio=mlp_ratio,
+                             attention_dropout=attn_drop_rate,
+                             dropout=drop_rate,
+                             dtype=dtype,
+                             checkpoint=checkpoint,
+                             activation=nn.functional.gelu) for _ in range(num_layers)
+        ])
+        if last:
+            norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+            # head = col_gpt.GPTLMHead(vocab_size=vocab_size,
+            #                          hidden_size=hidden_size,
+            #                          dtype=dtype,
+            #                          bias=False)
+            head = col_nn.Classifier(hidden_size, vocab_size, dtype=dtype, bias=False)
+        super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
+
+
+def _filter_kwargs(func, kwargs):
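+    # Keep only the kwargs that the target callable accepts, so one shared config dict can serve several module classes.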
+    sig = inspect.signature(func)
+    return {k: v for k, v in kwargs.items() if k in sig.parameters}
+
+
+def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+    logger = get_dist_logger()
+
+    if gpc.is_initialized(ParallelMode.PIPELINE):
+        pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+    else:
+        pipeline_size = 1
+        pipeline_rank = 0
+    rank = gpc.get_global_rank()
+
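+    # Share the word-embedding weights on the first stage with the LM head on the last stage across pipeline ranks.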
+    if pipeline_size > 1:
+        wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
+    else:
+        wrapper = None
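+    # Partition the layers uniformly over the pipeline stages; each (start, end) range becomes one model chunk built on this rank.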
+    parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
+    models = []
+    for start, end in parts:
+        kwargs['num_layers'] = end - start
+        kwargs['first'] = start == 0
+        kwargs['last'] = end == num_layers
+        logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
+        chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)
+
+        if wrapper is not None:
+            if start == 0:
+                wrapper.register_module(chunk.embedding.word_embeddings)
+            elif end == num_layers:
+                wrapper.register_module(chunk.head)
+        models.append(chunk)
+    if len(models) == 1:
+        model = models[0]
+    else:
+        model = nn.ModuleList(models)
+
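+    # Count local parameters; the size reported below assumes 2 bytes per parameter (fp16/bf16 weights).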
+    numel = 0
+    for _, param in model.named_parameters(recurse=True):
+        numel += param.numel()
+    logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
+    return model
+
+
+def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs):
+    model = FusedPipelineGPT1D if fused else PipelineGPT1D
+    return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs)
+
+
+def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+    return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs)
+
+
+def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
+    cfg = dict(hidden_size=768,
+               num_attention_heads=12,
+               checkpoint=checkpoint,
+               dtype=dtype,
+               embed_split_hidden=embed_split_hidden)
+    return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg)
+
+
+def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
+    cfg = dict(hidden_size=1600,
+               num_attention_heads=32,
+               checkpoint=checkpoint,
+               dtype=dtype,
+               embed_split_hidden=embed_split_hidden)
+    return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg)
+
+
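+# Roughly the GPT-3 175B configuration: 96 layers, hidden size 12288, 96 attention heads.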
+def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
+    cfg = dict(hidden_size=12288,
+               num_attention_heads=96,
+               checkpoint=checkpoint,
+               max_position_embeddings=2048,
+               dtype=dtype,
+               embed_split_hidden=embed_split_hidden)
+    return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg)
+
+
+def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
+    cfg = dict(hidden_size=1600,
+               num_attention_heads=32,
+               checkpoint=checkpoint,
+               dtype=dtype,
+               embed_split_hidden=embed_split_hidden)
+    return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg)
+
+
+def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
+    cfg = dict(hidden_size=768,
+               num_attention_heads=12,
+               checkpoint=checkpoint,
+               dtype=dtype,
+               embed_split_hidden=embed_split_hidden)
+    return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg)
+
+
+def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
+    cfg = dict(hidden_size=12288,
+               num_attention_heads=96,
+               checkpoint=checkpoint,
+               max_position_embeddings=2048,
+               dtype=dtype,
+               embed_split_hidden=embed_split_hidden)
+    return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg)
diff --git a/examples/language/gpt/titans/requirements.txt b/examples/language/gpt/titans/requirements.txt
new file mode 100644
index 000000000..64ff7a4ab
--- /dev/null
+++ b/examples/language/gpt/titans/requirements.txt
@@ -0,0 +1,4 @@
+torch==1.12.1
+titans==0.0.7
+colossalai==0.2.0+torch1.12cu11.3
+-f https://release.colossalai.org
diff --git a/examples/language/gpt/titans/run.sh b/examples/language/gpt/titans/run.sh
new file mode 100644
index 000000000..157bd377a
--- /dev/null
+++ b/examples/language/gpt/titans/run.sh
@@ -0,0 +1,2 @@
+export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
+colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch
diff --git a/examples/language/gpt/titans/test_ci.sh b/examples/language/gpt/titans/test_ci.sh
new file mode 100644
index 000000000..7cb24c1a4
--- /dev/null
+++ b/examples/language/gpt/titans/test_ci.sh
@@ -0,0 +1 @@
+colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch --use_dummy_dataset
diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py
new file mode 100644
index 000000000..1380b4b3a
--- /dev/null
+++ b/examples/language/gpt/titans/train_gpt.py
@@ -0,0 +1,148 @@
+import contextlib
+import os
+
+import torch
+import torch.nn as nn
+from titans.model.gpt import GPTLMLoss
+
+import colossalai
+import colossalai.utils as utils
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn import LinearWarmupLR
+from colossalai.trainer import Trainer, hooks
+from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
+from colossalai.utils.timer import MultiTimer
+from colossalai.zero.init_ctx import ZeroInitContext
+
+
+def calc_local_model_size(model: torch.nn.Module):
+    numel_per_device = 0
+    for p in model.parameters():
+        numel_per_device += p.numel()
+    return numel_per_device
+
+
+VOCAB_SIZE = 50257
+
+
+def main():
+    parser = colossalai.get_default_parser()
+    parser.add_argument('--from_torch', default=False, action='store_true')
+    parser.add_argument('--use_dummy_dataset', default=False, action='store_true')
+    args = parser.parse_args()
+    disable_existing_loggers()
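+    # --from_torch expects the environment variables set by torchrun / `colossalai run`; otherwise a SLURM launch is assumed.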
+    if args.from_torch:
+        colossalai.launch_from_torch(config=args.config)
+    else:
+        colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
+    logger = get_dist_logger()
+
+    if not args.use_dummy_dataset:
+        data_path = os.environ['DATA']
+        logger.info(f'Build data loader from path {data_path}', ranks=[0])
+        from dataset.webtext import WebtextDataset
+        train_ds = WebtextDataset(os.environ['DATA'], seq_len=gpc.config.SEQ_LEN)
+        train_dataloader = utils.get_dataloader(train_ds,
+                                                seed=42,
+                                                batch_size=gpc.config.BATCH_SIZE,
+                                                pin_memory=True,
+                                                shuffle=True,
+                                                drop_last=True)
+    else:
+        # build a dummy train_dataloader
+        logger.info('Build data loader using dummy data', ranks=[0])
+
+        def get_data(batch_size, seq_len, vocab_size):
+            input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
+            attention_mask = torch.ones_like(input_ids)
+            return input_ids, attention_mask
+
+        # 10 iterations
+        input_ids, attn_mask = get_data(gpc.config.BATCH_SIZE * 10, gpc.config.SEQ_LEN, VOCAB_SIZE)
+        from torch.utils.data import DataLoader, Dataset
+
+        class TextSamplerDataset(Dataset):
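+            # Serves random windows of SEQ_LEN + 1 tokens cut from the pre-generated dummy token tensor.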
+
+            def __init__(self, data, seq_len):
+                super().__init__()
+                self.data = data
+                self.seq_len = seq_len
+
+            def __getitem__(self, index):
+                rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+                full_seq = self.data[rand_start:rand_start + self.seq_len + 1].long()
+                return full_seq.cuda()
+
+            def __len__(self):
+                return self.data.size(0) // self.seq_len
+
+        def cycle(loader):
+            while True:
+                for data in loader:
+                    yield data
+
+        train_dataset = TextSamplerDataset(input_ids, gpc.config.SEQ_LEN)
+        train_dataloader = DataLoader(train_dataset, batch_size=gpc.config.BATCH_SIZE)
+
+    logger.info('Build model', ranks=[0])
+    use_pipeline = is_using_pp()
+    use_interleaved = hasattr(gpc.config.model, 'num_chunks')
+    use_zero3 = hasattr(gpc.config, 'zero')
+    ctx = contextlib.nullcontext()
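+    # When a ZeRO config is present, build the model inside ZeroInitContext so parameters are sharded as they are created.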
+    if use_zero3:
+        ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
+                              shard_strategy=gpc.config.zero.model_config.shard_strategy,
+                              shard_param=True)
+    with ctx:
+        model = gpc.config.model.pop('type')(**gpc.config.model)
+    if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList):
+        model = nn.ModuleList([model])
+
+    if use_zero3:
+        numel = ctx.model_numel_tensor.item()
+    else:
+        numel = calc_local_model_size(model)
+
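+    # Per-step TFLOP estimate fed to ThroughputHook: 8 * #params * #tokens, scaled by the model- and data-parallel world sizes.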
+    tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \
+        * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)
+
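+    # Use the loss function from the config if one is given, otherwise fall back to Titans' GPTLMLoss.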
+    criterion = getattr(gpc.config, 'loss_fn', None)
+    if criterion is not None:
+        criterion = criterion.type()
+    else:
+        criterion = GPTLMLoss()
+    logger.info('Build optimizer', ranks=[0])
+    optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)
+    lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
+    engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
+                                                                      optimizer,
+                                                                      criterion,
+                                                                      train_dataloader=train_dataloader,
+                                                                      lr_scheduler=lr_scheduler)
+    global_batch_size = gpc.config.BATCH_SIZE * \
+        gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
+    logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
+    timer = MultiTimer()
+    trainer = Trainer(engine=engine, logger=logger, timer=timer)
+    hook_list = [
+        hooks.LossHook(),
+        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
+        hooks.LogMetricByEpochHook(logger),
+        hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
+        hooks.LogMetricByStepHook(),
+        hooks.LogMemoryByEpochHook(logger),
+        # hooks.LogTimingByEpochHook(timer, logger),
+    ]
+    trainer.fit(train_dataloader=train_dataloader,
+                epochs=gpc.config.NUM_EPOCHS,
+                test_interval=1,
+                hooks=hook_list,
+                display_progress=True,
+                return_output_label=False)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index a334ea951..2f012780d 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -11,13 +11,12 @@ import tqdm
 from packaging import version
 from palm_pytorch import PaLM
 from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
-from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import GeminiDDP, ZeroDDP
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import MultiTimer, get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext