diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index 1f5345015..e071563c0 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -6,6 +6,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch import Tensor
+from torch.nn import Parameter
 from torch.utils._pytree import tree_map
 
 from colossalai._analyzer._subclasses import MetaTensor
@@ -99,8 +100,11 @@ def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
     Returns:
         torch.Tensor: the converted tensor
     """
-    cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
+    cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
     tensor.__class__ = cls_to_become
+    if cls_to_become is Parameter:
+        # drop the _is_param flag: _ParameterMeta.__instancecheck__ matches any tensor carrying it, even against UninitializedParameter
+        delattr(tensor, '_is_param')
     tensor.data = target
     tensor.requires_grad = target.requires_grad
     # subclass of torch.Tensor does not have tolist() method
@@ -198,10 +202,10 @@ class LazyTensor(torch.Tensor):
     def clean(self) -> None:
         """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
         """
-        self._factory_method = None
-        self._op_buffer = None
-        self._materialized_data = None
-        self._meta_data = None
+        delattr(self, '_factory_method')
+        delattr(self, '_op_buffer')
+        delattr(self, '_materialized_data')
+        delattr(self, '_meta_data')
 
     @staticmethod
     def _replace_with_materialized(x):
@@ -350,20 +354,19 @@ class LazyTensor(torch.Tensor):
         def factory_fn():
             # if self is materialized, return self
             new_tensor = self.materialize() if type(self) is LazyTensor else self
-            copied = new_tensor.detach().clone()
-            if new_tensor.requires_grad:
-                copied.requires_grad_()
-            return copied
+            return _copy_tensor(new_tensor, new_tensor.requires_grad)
 
         if self._materialized_data is not None:
             # self is early materialized
-            copied = self._materialized_data.detach().clone()
-            if self.requires_grad:
-                copied.requires_grad_()
+            copied = _copy_tensor(self._materialized_data, self.requires_grad)
             target = LazyTensor(lambda: None, concrete_data=copied)
         else:
             target = LazyTensor(factory_fn, meta_data=self._meta_data)
 
+        if isinstance(self, Parameter):
+            # set the _is_param flag that _ParameterMeta.__instancecheck__ reads, so isinstance(target, Parameter) holds
+            target._is_param = True
+
         memo[id(self)] = target
         return target
 
@@ -408,6 +411,11 @@
     def __hash__(self):
         return id(self)
 
+    def __rpow__(self, other):
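+        # scalar ** LazyTensor: compute the promoted dtype and build the base tensor on the same device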
+        dtype = torch.result_type(self, other)
+        return torch.tensor(other, dtype=dtype, device=self.device)**self
+
 
 class LazyInitContext:
     """Context manager for lazy initialization. Enables initializing the model without allocating real memory.
@@ -536,7 +543,7 @@ class LazyInitContext:
 
     @staticmethod
     def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
-        """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+        """Initialize all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
 
         Args:
             module (nn.Module): Target ``nn.Module``
@@ -553,7 +560,7 @@ class LazyInitContext:
                    device_mesh: DeviceMesh,
                    sharding_spec_dict: Dict[str, ShardingSpec],
                    verbose: bool = False) -> nn.Module:
-        """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+        """Distribute all ``Parameter`` from ``LazyTensor``. This function will modify the module in-place.
 
         Args:
             module (nn.Module): Target ``nn.Module``
@@ -625,3 +632,10 @@
         if not isinstance(x, int):
             return False
     return True
+
+
+def _copy_tensor(tensor: Tensor, requires_grad: bool) -> Tensor:
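+    # .data.clone() copies the values without autograd history; requires_grad is then restored explicitly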
+    copied = tensor.data.clone()
+    copied.requires_grad = requires_grad
+    return copied
diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py
index db39a457b..07341ef73 100644
--- a/colossalai/shardformer/layer/embedding.py
+++ b/colossalai/shardformer/layer/embedding.py
@@ -9,8 +9,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
 
+from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
@@ -95,6 +95,8 @@ class Embedding1D(ParallelModule):
         r"""
         Build a 1D parallelized Embedding from a native nn.Embedding module.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         num_embedding = module.num_embeddings
         embedding_dim = module.embedding_dim
@@ -223,6 +224,7 @@ class VocabParallelEmbedding1D(ParallelModule):
         r"""
         Convert a native pytorch embedding module to a parallel module.
         """
+        LazyInitContext.materialize(module)
         # get the origin attributes
         num_embeddings = module.num_embeddings
         embedding_dim = module.embedding_dim
@@ -243,6 +245,7 @@ class VocabParallelEmbedding1D(ParallelModule):
                                                       process_group=process_group,
                                                       *args,
                                                       **kwargs)
+
         with torch.no_grad():
             # shard and slice the weight along the vocabulary (num_embeddings) dimension
             # the shape of the weight is (num_embeddings, embedding_dim)
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 26ba5883c..a8439f303 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -12,6 +12,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
+from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
@@ -106,6 +107,7 @@ class Linear1D_Col(ParallelModule):
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features
         out_features = module.out_features
@@ -242,6 +244,7 @@ class Linear1D_Row(ParallelModule):
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features
         out_features = module.out_features
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index b27307154..9bb7738c0 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -4,6 +4,8 @@
 import torch
 import torch.nn as nn
 
+from colossalai.lazy import LazyInitContext
+
 __all__ = ['FusedLayerNorm', 'FusedRMSNorm']
 
 FAST_LAYERNORM_SUPPORTED_SIZE = [
@@ -35,6 +37,7 @@ class FusedLayerNorm():
             raise ImportError(
                 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel')
 
+        LazyInitContext.materialize(module)
         # get the attributes of the module
         normalized_shape = module.normalized_shape
         eps = module.eps
@@ -84,6 +87,7 @@ class FusedRMSNorm():
                 'Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel'
             )
 
+        LazyInitContext.materialize(module)
         # to check if it is huggingface LlamaRMSNorm
         if module.__class__.__name__ == "LlamaRMSNorm":
             normalized_shape = module.weight.shape[0]
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 9d51670c6..c94d93069 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -12,6 +12,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
+from colossalai.lazy import LazyInitContext
 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor.d_tensor.api import (
@@ -231,6 +232,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
             process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
             n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.weight.shape[0]
         out_features = module.weight.shape[1]
@@ -380,6 +382,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
         """
+        LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.weight.shape[0]
         out_features = module.weight.shape[1]
@@ -428,9 +431,10 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
                     src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
 
                 origin_device = self.bias.device
-                self.bias = self.bias.cuda()
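+                # assign through .data so self.bias stays a Parameter instead of being rebound to a plain Tensor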
+                self.bias.data = self.bias.cuda()
                 dist.broadcast(self.bias, src=src_rank, group=self.process_group)
-                self.bias = self.bias.to(origin_device)
+                self.bias.data = self.bias.to(origin_device)
 
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index e18cb6ece..b80475e05 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -46,11 +46,12 @@ class BertPolicy(Policy):
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         # TODO:
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model
 
     def module_policy(self):
@@ -229,10 +230,11 @@ class BertForPreTrainingPolicy(BertPolicy):
         return []
 
     def postprocess(self):
-        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model
 
 
@@ -269,10 +271,11 @@ class BertLMHeadModelPolicy(BertPolicy):
         return []
 
     def postprocess(self):
-        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model
 
 
@@ -288,10 +291,11 @@ class BertForMaskedLMPolicy(BertPolicy):
         return module_policy
 
     def postprocess(self):
-        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model
 
 
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 8d6f07d4a..662ff5b49 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -17,11 +17,12 @@ class BloomPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model
 
     def module_policy(self):
@@ -128,16 +129,13 @@ class BloomForCausalLMPolicy(BloomPolicy):
         return policy
 
     def postprocess(self):
-        binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
 
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-
-            if not isinstance(param, nn.Parameter):
-                param = nn.Parameter(param)
-
-            # tie weights
-            setattr_(self.model, v, param)
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                # tie weights
+                setattr_(self.model, v, param)
         return self.model
 
 
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 598f393c0..8f9d90e67 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -21,11 +21,12 @@ class GPT2Policy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model
 
     def module_policy(self):
@@ -142,10 +143,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
         return module_policy
 
     def postprocess(self):
-        binding_map = {"transformer.wte.weight": "lm_head.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"transformer.wte.weight": "lm_head.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model
 
 
@@ -172,10 +174,11 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
         return module_policy
 
     def postprocess(self):
-        binding_map = {"transformer.wte.weight": "lm_head.weight"}
-        for k, v in binding_map.items():
-            param = getattr_(self.model, k)
-            setattr_(self.model, v, param)
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"transformer.wte.weight": "lm_head.weight"}
+            for k, v in binding_map.items():
+                param = getattr_(self.model, k)
+                setattr_(self.model, v, param)
         return self.model
 
 
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 391938b27..b10e07560 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -15,13 +15,14 @@ class LlamaPolicy(Policy):
         pass
 
     def preprocess(self):
-        # Resize embedding
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
+        if self.shard_config.enable_tensor_parallelism:
+            # Resize embedding
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
 
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
 
         return self.model
 
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index c4c6cde01..1435805d2 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -19,11 +19,12 @@ class OPTPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model
 
     def module_policy(self):
@@ -116,14 +117,15 @@ class OPTForCausalLMPolicy(OPTPolicy):
         return policy
 
     def postprocess(self):
-        binding_map = {
-            'model.decoder.embed_tokens': 'lm_head',
-        }
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {
+                'model.decoder.embed_tokens': 'lm_head',
+            }
 
-        for k, v in binding_map.items():
-            src_mod = getattr_(self.model, k)
-            dst_mod = getattr_(self.model, v)
-            dst_mod.weight = src_mod.weight
+            for k, v in binding_map.items():
+                src_mod = getattr_(self.model, k)
+                dst_mod = getattr_(self.model, v)
+                dst_mod.weight = src_mod.weight
 
         return self.model
 
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 6167e8161..37864885b 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -24,11 +24,12 @@ class T5BasePolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        if self.shard_config.enable_tensor_parallelism:
+            vocab_size = self.model.config.vocab_size
+            world_size = self.shard_config.tensor_parallel_size
+            if vocab_size % world_size != 0:
+                new_vocab_size = vocab_size + world_size - vocab_size % world_size
+                self.model.resize_token_embeddings(new_vocab_size)
         return self.model
 
     def module_policy(self):
@@ -164,11 +165,12 @@ class T5BasePolicy(Policy):
         return policy
 
     def postprocess(self):
-        binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
 
-        for k, v in binding_map:
-            mod = getattr_(self.model, k)
-            setattr_(self.model, v, mod)
+            for k, v in binding_map:
+                mod = getattr_(self.model, k)
+                setattr_(self.model, v, mod)
         return self.model
 
 
@@ -211,13 +213,13 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
 
     def postprocess(self):
         super().postprocess()
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = {"shared": "lm_head"}
 
-        binding_map = {"shared": "lm_head"}
-
-        for k, v in binding_map.items():
-            src_mod = getattr_(self.model, k)
-            dst_mod = getattr_(self.model, v)
-            dst_mod.weight = src_mod.weight
+            for k, v in binding_map.items():
+                src_mod = getattr_(self.model, k)
+                dst_mod = getattr_(self.model, v)
+                dst_mod.weight = src_mod.weight
 
         return self.model
 
@@ -239,11 +241,12 @@ class T5EncoderPolicy(T5BasePolicy):
         return base_policy
 
     def postprocess(self):
-        binding_map = [
-            ["shared", "encoder.embed_tokens"],
-        ]
+        if self.shard_config.enable_tensor_parallelism:
+            binding_map = [
+                ["shared", "encoder.embed_tokens"],
+            ]
 
-        for k, v in binding_map:
-            mod = getattr_(self.model, k)
-            setattr_(self.model, v, mod)
+            for k, v in binding_map:
+                mod = getattr_(self.model, k)
+                setattr_(self.model, v, mod)
         return self.model
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index ca2f46a18..56eb76973 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 
-from colossalai.lazy import LazyTensor
+from colossalai.lazy import LazyInitContext
 
 from .._utils import getattr_, setattr_
 from ..policies.auto_policy import get_autopolicy
@@ -192,10 +192,5 @@ class ModelSharder(object):
         r"""
         Materialize the model if lazy initialization is used
         """
-        for p in self.model.parameters():
-            if isinstance(p, LazyTensor):
-                p.materialize()
-
-        for b in self.model.buffers():
-            if isinstance(b, LazyTensor):
-                b.materialize()
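+        # materialize parameters and buffers in one call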
+        LazyInitContext.materialize(self.model)
diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py
index 8a6aa42a4..99e494359 100644
--- a/tests/test_shardformer/test_layer/test_embedding.py
+++ b/tests/test_shardformer/test_layer/test_embedding.py
@@ -1,15 +1,23 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import Embedding1D
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
-def check_embedding_1d():
-    embedding = nn.Embedding(32, 128).cuda()
+@parameterize('lazy_init', [False, True])
+def check_embedding_1d(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
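+        # under LazyInitContext both construction and .cuda() are deferred; from_native_module materializes them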
+        embedding = nn.Embedding(32, 128).cuda()
     embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)
 
     assert embedding_1d.weight.shape == torch.Size([32, 64])
diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py
index fc6d894c4..2cb6928ed 100644
--- a/tests/test_shardformer/test_layer/test_layernorm.py
+++ b/tests/test_shardformer/test_layer/test_layernorm.py
@@ -1,14 +1,21 @@
+from contextlib import nullcontext
+
 import torch
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import FusedLayerNorm
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
-def check_layernorm():
-    norm = nn.LayerNorm(128, 0.00001).cuda()
+@parameterize('lazy_init', [False, True])
+def check_layernorm(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        norm = nn.LayerNorm(128, 0.00001).cuda()
     norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
 
     assert norm1d.weight.shape == torch.Size([128])
diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py
index da3bdc1d7..da3cd85ec 100644
--- a/tests/test_shardformer/test_layer/test_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_linear_1d.py
@@ -1,16 +1,23 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
 from colossalai.tensor.d_tensor import is_distributed_tensor
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
-def check_linear_1d_col():
-    linear = nn.Linear(32, 128).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_1d_col(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        linear = nn.Linear(32, 128).cuda()
     linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
 
     # ensure that the parameters are distributed
@@ -50,8 +57,12 @@ def check_linear_1d_col():
     assert_close(x_for_unshard.grad, x_for_shard.grad)
 
 
-def check_linear_1d_row():
-    linear = nn.Linear(32, 128).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_1d_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        linear = nn.Linear(32, 128).cuda()
     linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
 
     assert linear_row.weight.shape == torch.Size([128, 16])
@@ -83,9 +94,13 @@ def check_linear_1d_row():
     assert_close(x_for_unshard.grad, x_for_shard.grad)
 
 
-def check_linear_col_plus_row():
-    linear_1 = nn.Linear(32, 128).cuda()
-    linear_2 = nn.Linear(128, 32).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_col_plus_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        linear_1 = nn.Linear(32, 128).cuda()
+        linear_2 = nn.Linear(128, 32).cuda()
     linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
     linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
 
diff --git a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
index 681c4f6dd..186b1e821 100644
--- a/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
+++ b/tests/test_shardformer/test_layer/test_qkv_fused_linear_1d.py
@@ -1,12 +1,15 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
 # This code is copied from https://github.com/huggingface/transformers
@@ -50,8 +53,12 @@ def rearrange(tensor: torch.Tensor, dim: int):
     return rearanged_tensor
 
 
-def check_linear_conv_1d_col():
-    linear = Conv1D(192, 48).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_conv_1d_col(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        linear = Conv1D(192, 48).cuda()
     linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
                                                                    process_group=None,
                                                                    gather_output=True,
@@ -80,8 +87,12 @@ def check_linear_conv_1d_col():
     assert_close(target_grad, linear_conv_col.weight.grad)
 
 
-def check_linear_conv_1d_row():
-    linear = Conv1D(192, 48).cuda()
+@parameterize('lazy_init', [False, True])
+def check_linear_conv_1d_row(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        linear = Conv1D(192, 48).cuda()
     linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
 
     assert linear.weight.shape == torch.Size([48, 192])
diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py
index 8991d9b30..bf5803496 100644
--- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py
+++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py
@@ -1,15 +1,22 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer.layer import VocabParallelEmbedding1D
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
-def check_vocab_embedding_1d():
-    embedding = nn.Embedding(128, 32).to('cuda')
+@parameterize('lazy_init', [False, True])
+def check_vocab_embedding_1d(lazy_init: bool):
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
+    with ctx:
+        embedding = nn.Embedding(128, 32).to('cuda')
     dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)
 
     assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index e03014f3f..f83cfcd49 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -1,19 +1,25 @@
 import copy
+from contextlib import nullcontext
 
+from colossalai.lazy import LazyInitContext
 from colossalai.shardformer import ShardConfig, ShardFormer
 
 
-def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True):
-    # create new model
-    org_model = model_fn().cuda()
-
+def build_model(model_fn, enable_fused_normalization=True, enable_tensor_parallelism=True, use_lazy_init: bool = False):
+    ctx = LazyInitContext() if use_lazy_init else nullcontext()
+    with ctx:
+        # create new model
+        org_model = model_fn()
+        model_copy = copy.deepcopy(org_model)
+    if use_lazy_init:
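+        # materialize only the original; the sharder materializes model_copy during optimize()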
+        ctx.materialize(org_model)
     # shard model
     shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                enable_tensor_parallelism=enable_tensor_parallelism)
-    model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
     sharded_model, shared_params = shard_former.optimize(model_copy)
-    return org_model, sharded_model.cuda()
+    return org_model.cuda(), sharded_model.cuda()
 
 
 def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 1afedb707..7f179acd7 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -67,12 +67,14 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
                           atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
 
 
-@parameterize('enable_fused_normalization', [True, False])
-@parameterize('enable_tensor_parallelism', [True, False])
-def run_bert_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('enable_fused_normalization', [False, True])
+@parameterize('enable_tensor_parallelism', [False, True])
+@parameterize('use_lazy_init', [False, True])
+def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index a33896522..e18168292 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
 
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index ee7737687..96c4b90a8 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
 
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 74b5fdd18..4d63a4348 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -72,10 +72,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
 
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 25bccb13b..c008596fe 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -71,10 +71,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
 
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 0762dc09e..ccd7d3787 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -82,10 +82,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
 
diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py
index f29c8d6f6..2b6933246 100644
--- a/tests/test_shardformer/test_with_torch_ddp.py
+++ b/tests/test_shardformer/test_with_torch_ddp.py
@@ -1,3 +1,5 @@
+from contextlib import nullcontext
+
 import pytest
 import torch
 import torch.distributed as dist
@@ -5,15 +7,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 
 
-def check_shardformer_with_ddp(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+@parameterize('lazy_init', [True, False])
+def check_shardformer_with_ddp(lazy_init: bool):
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
 
@@ -41,9 +43,12 @@ def check_shardformer_with_ddp(rank, world_size, port):
     shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
     shardformer = ShardFormer(shard_config=shard_config)
 
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         # create and shard model
-        model = model_fn().cuda()
+        with ctx:
+            model = model_fn().cuda()
         sharded_model, _ = shardformer.optimize(model)
 
         # add ddp
@@ -65,13 +70,18 @@ def check_shardformer_with_ddp(rank, world_size, port):
         torch.cuda.empty_cache()
 
 
+def run_dist(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_shardformer_with_ddp()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_gpt2():
-    spawn(check_shardformer_with_ddp, 4)
+    spawn(run_dist, 4)
 
 
 if __name__ == "__main__":
     test_gpt2()
-    test_gpt2()