From b2483c8e31c7291f1d083d7c9c59bd79203c3c3b Mon Sep 17 00:00:00 2001
From: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Date: Mon, 12 Aug 2024 18:17:05 +0800
Subject: [PATCH] [fp8] support hybrid parallel plugin (#5982)

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* support fp8 comm for qwen2 model

* fp8

* fix

* bert and bloom

* chatglm and command

* gpt2,gptj,bert, falcon,blip2

* mistral,opy,sam,t5,vit,whisper

* fix

* fix

* fix
---
 colossalai/shardformer/layer/embedding.py     |  6 +-
 colossalai/shardformer/layer/linear.py        |  2 +
 .../shardformer/layer/qkv_fused_linear.py     |  6 +-
 colossalai/shardformer/modeling/bert.py       | 30 ++++++--
 colossalai/shardformer/modeling/bloom.py      | 20 ++++-
 colossalai/shardformer/modeling/chatglm2.py   | 37 +++++++++-
 colossalai/shardformer/modeling/command.py    | 30 ++++++--
 colossalai/shardformer/modeling/gpt2.py       |  2 +
 colossalai/shardformer/modeling/gptj.py       |  4 +
 colossalai/shardformer/modeling/llama.py      | 16 +++-
 colossalai/shardformer/modeling/qwen2.py      | 30 ++++++--
 colossalai/shardformer/policies/bert.py       | 22 +++++-
 colossalai/shardformer/policies/blip2.py      | 70 +++++++++++++++++-
 colossalai/shardformer/policies/bloom.py      | 40 ++++++++--
 colossalai/shardformer/policies/chatglm2.py   | 16 +++-
 colossalai/shardformer/policies/command.py    | 24 ++++--
 colossalai/shardformer/policies/falcon.py     |  9 ++-
 colossalai/shardformer/policies/gpt2.py       | 11 ++-
 colossalai/shardformer/policies/gptj.py       | 22 +++++-
 colossalai/shardformer/policies/llama.py      | 10 ++-
 colossalai/shardformer/policies/mistral.py    | 39 +++++++++-
 colossalai/shardformer/policies/opt.py        | 22 +++++-
 colossalai/shardformer/policies/qwen2.py      | 33 ++++++---
 colossalai/shardformer/policies/sam.py        | 64 ++++++++++++++++
 colossalai/shardformer/policies/t5.py         | 73 +++++++++++++++++--
 colossalai/shardformer/policies/vit.py        | 20 ++++-
 colossalai/shardformer/policies/whisper.py    | 58 ++++++++++++++-
 27 files changed, 633 insertions(+), 83 deletions(-)

diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py
index 186063503..18efb0ec5 100644
--- a/colossalai/shardformer/layer/embedding.py
+++ b/colossalai/shardformer/layer/embedding.py
@@ -68,6 +68,7 @@ class Embedding1D(ParallelModule):
         gather_output: bool = True,
         weight: Optional[nn.Parameter] = None,
         weight_initializer: Callable = init.normal_(),
+        fp8_communication: bool = False,
         *args,
         **kwargs,
     ):
@@ -81,6 +82,7 @@ class Embedding1D(ParallelModule):
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.gather_output = gather_output
+        self.fp8_communication = fp8_communication
 
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -155,7 +157,9 @@ class Embedding1D(ParallelModule):
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
         if self.gather_output:
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            output = gather_forward_split_backward(
+                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )
             return output
         else:
             return output_parallel
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index 38a6ef1a1..4e5ebef0d 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -572,6 +572,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
         weight: Optional[Parameter] = None,
         bias_: Optional[Parameter] = None,
         make_vocab_size_divisible_by: int = 64,
+        fp8_communication: bool = False,
         **kwargs,
     ):
         # create weight and bias
@@ -602,6 +603,7 @@ class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
             **kwargs,
             new_num_embeddings=new_out_features,
             old_num_embeddings=out_features,
+            fp8_communication=fp8_communication,
         )
         # get the length of valid embeddings
         tp_rank = dist.get_rank(process_group)
diff --git a/colossalai/shardformer/layer/qkv_fused_linear.py b/colossalai/shardformer/layer/qkv_fused_linear.py
index 93a7eb231..561867993 100644
--- a/colossalai/shardformer/layer/qkv_fused_linear.py
+++ b/colossalai/shardformer/layer/qkv_fused_linear.py
@@ -627,6 +627,7 @@ class FusedLinear1D_Col(ParallelModule):
         bias_: Optional[Parameter] = None,
         weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
         bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        fp8_communication: bool = False,
     ):
         super().__init__()
         # Keep input parameters
@@ -638,6 +639,7 @@ class FusedLinear1D_Col(ParallelModule):
         self.n_fused = n_fused
         self.process_group = process_group
         self.async_communication = async_communication
+        self.fp8_communication = fp8_communication
 
         if skip_bias_add and not bias:
             raise ValueError("cannot skip bias addition if bias is None")
@@ -767,7 +769,9 @@ class FusedLinear1D_Col(ParallelModule):
 
         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            output = gather_forward_split_backward(
+                output_parallel, dim=-1, process_group=self.process_group, fp8_communication=self.fp8_communication
+            )
         else:
             output = output_parallel
 
diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 7710b56e7..580f3618c 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -187,11 +187,17 @@ class BertPipelineForwards:
         if shard_config is not None and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
                 if encoder_hidden_states is not None:
                     encoder_hidden_states = split_forward_gather_backward(
-                        encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                        encoder_hidden_states,
+                        dim=1,
+                        process_group=shard_config.tensor_parallel_process_group,
+                        fp8_communication=shard_config.fp8_communication,
                     )
 
         for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
@@ -242,7 +248,10 @@ class BertPipelineForwards:
         if shard_config is not None and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         if output_hidden_states:
@@ -1135,11 +1144,17 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         embedding_output = split_forward_gather_backward(
-            embedding_output, dim=1, process_group=shard_config.tensor_parallel_process_group
+            embedding_output,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
         if encoder_hidden_states is not None:
             encoder_hidden_states = split_forward_gather_backward(
-                encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                encoder_hidden_states,
+                dim=1,
+                process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
 
         encoder_outputs = self.encoder(
@@ -1159,7 +1174,10 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
 
         # When sequence parallelism done, gather the output tensor in forward and split it in backward
         sequence_output = gather_forward_split_backward(
-            sequence_output, dim=1, process_group=shard_config.tensor_parallel_process_group
+            sequence_output,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
 
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 26ffef6c5..f8fd4665f 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -221,7 +221,10 @@ class BloomPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         start_idx, end_idx = stage_index[0], stage_index[1]
@@ -264,7 +267,10 @@ class BloomPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=1,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         if stage_manager.is_last_stage():
@@ -922,7 +928,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
-            hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -960,7 +969,10 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
 
         # When sequence parallelism done, gather the output tensor in forward and split it in backward
         hidden_states = gather_forward_split_backward(
-            hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index 34d900d8d..5fd7b6461 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -206,6 +206,7 @@ class ChatGLMPipelineForwards:
                     hidden_states,
                     dim=0,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = split_forward_gather_backward(
@@ -213,6 +214,7 @@ class ChatGLMPipelineForwards:
                     dim=0,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=1 / shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
@@ -245,6 +247,7 @@ class ChatGLMPipelineForwards:
                     hidden_states,
                     dim=0,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = gather_forward_split_backward(
@@ -252,6 +255,7 @@ class ChatGLMPipelineForwards:
                     dim=0,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -414,6 +418,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                 inputs_embeds,
                 dim=0,
                 process_group=sp_group,
+                fp8_communication=shard_config.fp8_communication,
             )
         elif sp_mode == "all_to_all":
             inputs_embeds = split_forward_gather_backward(
@@ -421,6 +426,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                 dim=0,
                 process_group=sp_group,
                 grad_scale=1 / sp_size,
+                fp8_communication=shard_config.fp8_communication,
             )
         hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
             inputs_embeds,
@@ -436,6 +442,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                 hidden_states,
                 dim=0,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
         elif sp_mode == "all_to_all":
             hidden_states = gather_forward_split_backward(
@@ -443,6 +450,7 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig, sp_mode,
                 dim=0,
                 process_group=sp_group,
                 grad_scale=sp_size,
+                fp8_communication=shard_config.fp8_communication,
             )
 
         if not return_dict:
@@ -532,9 +540,24 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
             key_layer = key_layer.reshape(sq, bs, -1)
             value_layer = value_layer.reshape(sq, bs, -1)
 
-            query_layer = all_to_all_comm(query_layer, sp_group, gather_dim=0)
-            key_layer = all_to_all_comm(key_layer, sp_group, gather_dim=0)
-            value_layer = all_to_all_comm(value_layer, sp_group, gather_dim=0)
+            query_layer = all_to_all_comm(
+                query_layer,
+                sp_group,
+                gather_dim=0,
+                fp8_communication=shard_config.fp8_communication,
+            )
+            key_layer = all_to_all_comm(
+                key_layer,
+                sp_group,
+                gather_dim=0,
+                fp8_communication=shard_config.fp8_communication,
+            )
+            value_layer = all_to_all_comm(
+                value_layer,
+                sp_group,
+                gather_dim=0,
+                fp8_communication=shard_config.fp8_communication,
+            )
 
             query_layer = query_layer.view(
                 sq * sp_size,
@@ -610,7 +633,13 @@ def get_chatglm_sequence_parallel_attention_forward(shard_config: ShardConfig, s
 
         context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
         if sp_mode == "all_to_all":
-            context_layer = all_to_all_comm(context_layer, sp_group, gather_dim=2, scatter_dim=0)
+            context_layer = all_to_all_comm(
+                context_layer,
+                sp_group,
+                gather_dim=2,
+                scatter_dim=0,
+                fp8_communication=shard_config.fp8_communication,
+            )
 
         # =================
         # Output. [sq, b, h]
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 5b36fc7db..1ece0c118 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -140,6 +140,7 @@ class CommandPipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = split_forward_gather_backward(
@@ -147,6 +148,7 @@ class CommandPipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=1 / shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         # decoder layers
@@ -211,6 +213,7 @@ class CommandPipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = gather_forward_split_backward(
@@ -218,6 +221,7 @@ class CommandPipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         # add hidden states from the last decoder layer
@@ -382,9 +386,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
 
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -446,7 +450,9 @@ def get_command_flash_attention_forward(shard_config, sp_mode=None, sp_size=None
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -526,9 +532,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
             attention_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
 
         if sp_mode in ["ring", "split_gather"]:
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
+            inputs_embeds = split_forward_gather_backward(
+                inputs_embeds, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
         hidden_states = inputs_embeds
 
         # decoder layers
@@ -573,9 +583,13 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 0dbf0ca5a..97544e110 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -221,6 +221,7 @@ class GPT2PipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         # Going through held blocks.
@@ -276,6 +277,7 @@ class GPT2PipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         if stage_manager.is_last_stage():
diff --git a/colossalai/shardformer/modeling/gptj.py b/colossalai/shardformer/modeling/gptj.py
index facd2fcaf..51b228712 100644
--- a/colossalai/shardformer/modeling/gptj.py
+++ b/colossalai/shardformer/modeling/gptj.py
@@ -185,6 +185,7 @@ class GPTJPipelineForwards:
                 hidden_states,
                 dim=1,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
 
         # Going through held blocks.
@@ -236,6 +237,7 @@ class GPTJPipelineForwards:
                 hidden_states,
                 dim=1,
                 process_group=shard_config.tensor_parallel_process_group,
+                fp8_communication=shard_config.fp8_communication,
             )
 
         if stage_manager.is_last_stage():
@@ -915,6 +917,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
             hidden_states,
             dim=1,
             process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -978,6 +981,7 @@ def gptj_sequence_parallel_forward_fn(shard_config: ShardConfig):
             hidden_states,
             dim=1,
             process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
         )
 
         hidden_states = self.ln_f(hidden_states)
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 693f6584f..99b31745b 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -143,9 +143,13 @@ class LlamaPipelineForwards:
         # Support SP + PP
         if stage_manager.is_first_stage():
             if sp_mode in ["ring", "split_gather"]:
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
             elif sp_mode == "all_to_all":
-                hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
+                hidden_states = split_forward_gather_backward(
+                    hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+                )
 
         if self.gradient_checkpointing and self.training and use_cache:
             if use_cache:
@@ -210,9 +214,13 @@ class LlamaPipelineForwards:
         if stage_manager.is_last_stage():
             hidden_states = self.norm(hidden_states)
             if sp_mode == "ring" or sp_mode == "split_gather":
-                hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+                hidden_states = gather_forward_split_backward(
+                    hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+                )
             elif sp_mode == "all_to_all":
-                hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+                hidden_states = gather_forward_split_backward(
+                    hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+                )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index 538e96c32..a61360d17 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -175,6 +175,7 @@ class Qwen2PipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = split_forward_gather_backward(
@@ -182,6 +183,7 @@ class Qwen2PipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=1 / shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
 
         # decoder layers
@@ -246,6 +248,7 @@ class Qwen2PipelineForwards:
                     hidden_states,
                     dim=1,
                     process_group=shard_config.tensor_parallel_process_group,
+                    fp8_communication=shard_config.fp8_communication,
                 )
             elif shard_config.sequence_parallelism_mode == "all_to_all":
                 hidden_states = gather_forward_split_backward(
@@ -253,6 +256,7 @@ class Qwen2PipelineForwards:
                     dim=1,
                     process_group=shard_config.sequence_parallel_process_group,
                     grad_scale=shard_config.sequence_parallel_size,
+                    fp8_communication=shard_config.fp8_communication,
                 )
         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -516,9 +520,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         value_states = self.v_proj(hidden_states)
         # sp: all-to-all comminucation when introducing sequence parallel
         if sp_mode == "all_to_all":
-            query_states = all_to_all_comm(query_states, sp_group)
-            key_states = all_to_all_comm(key_states, sp_group)
-            value_states = all_to_all_comm(value_states, sp_group)
+            query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+            value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
             bsz, q_len, _ = query_states.size()
 
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -604,7 +608,9 @@ def get_qwen2_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
         attn_output = attn_output.transpose(1, 2).contiguous()
         if sp_mode == "all_to_all":
             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
-            attn_output = all_to_all_comm(attn_output, sp_group, scatter_dim=1, gather_dim=2)
+            attn_output = all_to_all_comm(
+                attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+            )
         else:
             attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
@@ -702,9 +708,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         next_decoder_cache = None
 
         if sp_mode in ["ring", "split_gather"]:
-            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
+            hidden_states = split_forward_gather_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
+            hidden_states = split_forward_gather_backward(
+                hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         for decoder_layer in self.layers:
             if output_hidden_states:
@@ -741,9 +751,13 @@ def get_qwen2_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=No
         hidden_states = self.norm(hidden_states)
 
         if sp_mode == "ring" or sp_mode == "split_gather":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+            )
         elif sp_mode == "all_to_all":
-            hidden_states = gather_forward_split_backward(hidden_states, 1, sp_group, grad_scale=sp_size)
+            hidden_states = gather_forward_split_backward(
+                hidden_states, 1, sp_group, grad_scale=sp_size, fp8_communication=shard_config.fp8_communication
+            )
 
         # add hidden states from the last decoder layer
         if output_hidden_states:
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index b84a372a5..4c33e14bc 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -98,6 +98,7 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -106,6 +107,7 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -114,6 +116,7 @@ class BertPolicy(Policy):
                         kwargs={
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -123,7 +126,10 @@ class BertPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="attention.output.dense",
                         target_module=col_nn.Linear1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.output.dropout",
@@ -136,12 +142,16 @@ class BertPolicy(Policy):
                             "seq_parallel_mode": sp_mode,
                             "overlap": overlap,
                             "skip_bias_add": self.enable_bias_gelu_fused,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dense",
                         target_module=col_nn.Linear1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dropout",
@@ -180,6 +190,13 @@ class BertPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="word_embeddings",
                         target_module=embedding_cls,
+                        kwargs=(
+                            {
+                                "fp8_communication": self.shard_config.fp8_communication,
+                            }
+                            if self.shard_config.enable_tensor_parallelism
+                            else {}
+                        ),
                     )
                 ],
                 policy=policy,
@@ -249,6 +266,7 @@ class BertPolicy(Policy):
                     kwargs={
                         "gather_output": True,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 policy=base_policy,
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index 32d4edadb..da798f6a0 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -72,20 +72,30 @@ class BlipPolicy(Policy):
                         target_module=col_nn.FusedLinear1D_Col,
                         kwargs={
                             "n_fused": 3,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.projection",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.fc1",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"skip_bias_add": self.enable_bias_gelu_fused},
+                        kwargs={
+                            "skip_bias_add": self.enable_bias_gelu_fused,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.fc2",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
@@ -114,14 +124,23 @@ class BlipPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="attention.attention.query",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.attention.key",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.attention.value",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.attention.dropout",
@@ -130,6 +149,9 @@ class BlipPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="attention.output.dense",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.output.dropout",
@@ -138,14 +160,23 @@ class BlipPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="crossattention.attention.query",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="crossattention.attention.key",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="crossattention.attention.value",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="crossattention.attention.dropout",
@@ -154,6 +185,9 @@ class BlipPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="crossattention.output.dense",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="crossattention.output.dropout",
@@ -162,10 +196,16 @@ class BlipPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="intermediate_query.dense",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output_query.dense",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output_query.dropout",
@@ -185,26 +225,44 @@ class BlipPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.out_proj",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="fc1",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="fc2",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
@@ -225,7 +283,14 @@ class BlipPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="model.decoder.embed_tokens",
                         target_module=embedding_cls,
-                        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                        kwargs=(
+                            {
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
+                            }
+                            if self.shard_config.enable_tensor_parallelism
+                            else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                        ),
                     ),
                 ],
                 policy=policy,
@@ -241,6 +306,7 @@ class BlipPolicy(Policy):
                         kwargs={
                             "gather_output": True,
                             "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                 ],
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index d80adb84a..a43ac02d0 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -76,12 +76,19 @@ class BloomPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attention.query_key_value",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "overlap": overlap,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attention.dense",
                         target_module=col_nn.Linear1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attention.attention_dropout",
@@ -90,12 +97,19 @@ class BloomPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="mlp.dense_h_to_4h",
                         target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel_mode": sp_mode, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "overlap": overlap,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.dense_4h_to_h",
                         target_module=col_nn.Linear1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
@@ -115,7 +129,14 @@ class BloomPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="word_embeddings",
                         target_module=embedding_cls,
-                        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                        kwargs=(
+                            {
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
+                            }
+                            if self.shard_config.enable_tensor_parallelism
+                            else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                        ),
                     ),
                 ],
                 policy=policy,
@@ -279,6 +300,7 @@ class BloomForCausalLMPolicy(BloomPolicy):
                     kwargs=dict(
                         gather_output=not self.shard_config.parallel_output,
                         make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
+                        fp8_communication=self.shard_config.fp8_communication,
                     ),
                 ),
                 policy=policy,
@@ -337,7 +359,9 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
-                    suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
+                    suffix="score",
+                    target_module=col_nn.Linear1D_Col,
+                    kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                 ),
                 policy=policy,
                 target_key=BloomForSequenceClassification,
@@ -374,7 +398,9 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
             self.append_or_create_submodule_replacement(
                 description=[
                     SubModuleReplacementDescription(
-                        suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
+                        suffix="classifier",
+                        target_module=col_nn.Linear1D_Col,
+                        kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="dropout",
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index 3877bdac3..16c13085a 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -128,12 +128,17 @@ class ChatGLMPolicy(Policy):
                             "seq_parallel_mode": sp_mode,
                             "seq_parallel_dim": 0,
                             "overlap": overlap,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attention.dense",
                         target_module=col_nn.Linear1D_Row,
-                        kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "seq_parallel_dim": 0,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attention.core_attention.attention_dropout",
@@ -148,7 +153,14 @@ class ChatGLMPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="embedding.word_embeddings",
                         target_module=embedding_cls,
-                        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                        kwargs=(
+                            {
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
+                            }
+                            if self.shard_config.enable_tensor_parallelism
+                            else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                        ),
                     ),
                 ],
                 policy=policy,
diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py
index a9b915d10..d0903b11c 100644
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -126,37 +126,37 @@ class CommandPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=Linear1D_Col,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=Linear1D_Col,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.gate_proj",
                         target_module=Linear1D_Col,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.up_proj",
                         target_module=Linear1D_Col,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.down_proj",
                         target_module=Linear1D_Row,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                 ],
             )
@@ -166,7 +166,14 @@ class CommandPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=CohereModel,
@@ -303,6 +310,7 @@ class CommandForCausalLMPolicy(CommandPolicy):
                             kwargs={
                                 "gather_output": not self.shard_config.parallel_output,
                                 "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
                             },
                         )
                     ],
diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index e5c167337..e20fb1568 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -105,7 +105,14 @@ class FalconPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="word_embeddings",
                         target_module=embedding_cls,
-                        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                        kwargs=(
+                            {
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
+                            }
+                            if self.shard_config.enable_tensor_parallelism
+                            else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                        ),
                     ),
                 ],
                 policy=policy,
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index bb6269737..0a1949d85 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -162,7 +162,14 @@ class GPT2Policy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="wte",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=GPT2Model,
@@ -332,6 +339,7 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
                             kwargs={
                                 "gather_output": False,
                                 "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
                             },
                         )
                     ],
@@ -402,6 +410,7 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
                             kwargs={
                                 "gather_output": True,
                                 "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
                             },
                         )
                     ]
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index c394d911e..6f0c8803c 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -77,6 +77,7 @@ class GPTJPolicy(Policy):
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
                             "overlap": overlap,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -84,6 +85,7 @@ class GPTJPolicy(Policy):
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
                             "overlap": overlap,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
@@ -91,19 +93,29 @@ class GPTJPolicy(Policy):
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
                             "overlap": overlap,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.out_proj",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.fc_in",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.fc_out",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.attn_dropout",
@@ -125,7 +137,14 @@ class GPTJPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="wte",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=GPTJModel,
@@ -264,6 +283,7 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
                             kwargs={
                                 "gather_output": True,
                                 "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
                             },
                         )
                     ]
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 6f8404219..23d2d3913 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -166,7 +166,14 @@ class LlamaPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=LlamaModel,
@@ -308,6 +315,7 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
                             kwargs={
                                 "gather_output": not self.shard_config.parallel_output,
                                 "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
                             },
                         )
                     ],
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index c5a0277a5..72a5158e5 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -88,30 +88,51 @@ class MistralPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.gate_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.up_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.down_proj",
                         target_module=Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
@@ -121,7 +142,14 @@ class MistralPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=MistralModel,
@@ -281,6 +309,7 @@ class MistralForCausalLMPolicy(MistralPolicy):
                             kwargs={
                                 "gather_output": not self.shard_config.parallel_output,
                                 "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
                             },
                         )
                     ]
@@ -297,7 +326,9 @@ class MistralForCausalLMPolicy(MistralPolicy):
                         SubModuleReplacementDescription(
                             suffix="lm_head",
                             target_module=PaddingLMHead,
-                            kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
+                            kwargs=dict(
+                                make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
+                            ),
                         )
                     ]
                 )
@@ -350,7 +381,9 @@ class MistralForSequenceClassificationPolicy(MistralPolicy):
                 MistralForSequenceClassification: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                            suffix="score",
+                            target_module=Linear1D_Col,
+                            kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                         )
                     ]
                 )
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 524d2b8cd..dd64ce652 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -102,18 +102,30 @@ class OPTPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="q_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="k_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="v_proj",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="out_proj",
                         target_module=Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
@@ -123,7 +135,14 @@ class OPTPolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=OPTDecoder,
@@ -272,6 +291,7 @@ class OPTForCausalLMPolicy(OPTPolicy):
                     kwargs=dict(
                         gather_output=not self.shard_config.parallel_output,
                         make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
+                        fp8_communication=self.shard_config.fp8_communication,
                     ),
                 ),
                 policy=policy,
diff --git a/colossalai/shardformer/policies/qwen2.py b/colossalai/shardformer/policies/qwen2.py
index 362c14060..1b066200d 100644
--- a/colossalai/shardformer/policies/qwen2.py
+++ b/colossalai/shardformer/policies/qwen2.py
@@ -119,37 +119,37 @@ class Qwen2Policy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=Linear1D_Col,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=Linear1D_Col,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=Linear1D_Col,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.o_proj",
                         target_module=Linear1D_Row,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.gate_proj",
                         target_module=Linear1D_Col,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.up_proj",
                         target_module=Linear1D_Col,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.down_proj",
                         target_module=Linear1D_Row,
-                        kwargs=dict(seq_parallel_mode=sp_mode),
+                        kwargs=dict(seq_parallel_mode=sp_mode, fp8_communication=self.shard_config.fp8_communication),
                     ),
                 ],
             )
@@ -159,7 +159,14 @@ class Qwen2Policy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=Qwen2Model,
@@ -317,7 +324,11 @@ class Qwen2ForCausalLMPolicy(Qwen2Policy):
             new_item = {
                 Qwen2ForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
-                        SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col)
+                        SubModuleReplacementDescription(
+                            suffix="lm_head",
+                            target_module=Linear1D_Col,
+                            kwargs=dict(fp8_communication=self.shard_config.fp8_communication),
+                        )
                     ],
                     method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
                 )
@@ -366,7 +377,9 @@ class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
                 Qwen2ForSequenceClassification: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                            suffix="score",
+                            target_module=Linear1D_Col,
+                            kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                         )
                     ]
                 )
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index 53faf8997..674fe5e58 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -43,19 +43,29 @@ class SamPolicy(Policy):
                         target_module=col_nn.FusedLinear1D_Col,
                         kwargs={
                             "n_fused": 3,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attn.proj",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.lin1",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.lin2",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
@@ -68,58 +78,100 @@ class SamPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.out_proj",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="cross_attn_token_to_image.q_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="cross_attn_token_to_image.k_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="cross_attn_token_to_image.v_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="cross_attn_token_to_image.out_proj",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.lin1",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="mlp.lin2",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="cross_attn_image_to_token.q_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="cross_attn_image_to_token.k_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="cross_attn_image_to_token.v_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="cross_attn_image_to_token.out_proj",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
@@ -132,18 +184,30 @@ class SamPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="final_attn_token_to_image.q_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="final_attn_token_to_image.k_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="final_attn_token_to_image.v_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="final_attn_token_to_image.out_proj",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 0b594678c..84b5d9594 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -117,23 +117,38 @@ class T5BasePolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="q",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="k",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="v",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="o",
                         target_module=Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="relative_attention_bias",
                         target_module=Embedding1D,
-                        kwargs=dict(gather_output=False),
+                        kwargs=dict(
+                            gather_output=False,
+                            fp8_communication=self.shard_config.fp8_communication,
+                        ),
                         ignore_if_not_exist=True,
                     ),
                 ],
@@ -151,13 +166,24 @@ class T5BasePolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="wi_0 ",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="wi_1",
                         target_module=Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
-                        suffix="wo", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        suffix="wo",
+                        target_module=Linear1D_Col,
+                        kwargs=dict(
+                            gather_output=True,
+                            fp8_communication=self.shard_config.fp8_communication,
+                        ),
                     ),
                     SubModuleReplacementDescription(
                         suffix="dropout",
@@ -170,10 +196,16 @@ class T5BasePolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="wi",
                         target_module=Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="wo",
                         target_module=Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="dropout",
@@ -187,7 +219,14 @@ class T5BasePolicy(Policy):
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=T5Stack,
@@ -407,7 +446,14 @@ class T5ModelPolicy(T5BasePolicy):
                 description=SubModuleReplacementDescription(
                     suffix="shared",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=T5Model,
@@ -451,7 +497,14 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
                 description=SubModuleReplacementDescription(
                     suffix="shared",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=T5ForConditionalGeneration,
@@ -465,6 +518,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
                     kwargs={
                         "gather_output": True,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 policy=policy,
@@ -539,7 +593,14 @@ class T5EncoderPolicy(T5BasePolicy):
                 description=SubModuleReplacementDescription(
                     suffix="shared",
                     target_module=embedding_cls,
-                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                    kwargs=(
+                        {
+                            "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        }
+                        if self.shard_config.enable_tensor_parallelism
+                        else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                    ),
                 ),
                 policy=policy,
                 target_key=T5EncoderModel,
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 069ad0c26..07202094f 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -70,14 +70,23 @@ class ViTPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="attention.attention.query",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.attention.key",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.attention.value",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.attention.dropout",
@@ -86,6 +95,9 @@ class ViTPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="attention.output.dense",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.output.dropout",
@@ -96,11 +108,15 @@ class ViTPolicy(Policy):
                         target_module=col_nn.Linear1D_Col,
                         kwargs={
                             "skip_bias_add": self.enable_bias_gelu_fused,
+                            "fp8_communication": self.shard_config.fp8_communication,
                         },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dense",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="output.dropout",
@@ -215,7 +231,9 @@ class ViTForImageClassificationPolicy(ViTPolicy):
                 ViTForImageClassification: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="classifier", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                            suffix="classifier",
+                            target_module=Linear1D_Col,
+                            kwargs=dict(gather_output=True, fp8_communication=self.shard_config.fp8_communication),
                         )
                     ]
                 )
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index 441e512bb..7a1f146d5 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -91,26 +91,44 @@ class WhisperPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.out_proj",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="fc1",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="fc2",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
@@ -128,42 +146,72 @@ class WhisperPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="self_attn.q_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.k_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.v_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="self_attn.out_proj",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="encoder_attn.q_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="encoder_attn.k_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="encoder_attn.v_proj",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="encoder_attn.out_proj",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="fc1",
                         target_module=col_nn.Linear1D_Col,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                     SubModuleReplacementDescription(
                         suffix="fc2",
                         target_module=col_nn.Linear1D_Row,
+                        kwargs={
+                            "fp8_communication": self.shard_config.fp8_communication,
+                        },
                     ),
                 ],
             )
@@ -174,7 +222,14 @@ class WhisperPolicy(Policy):
                     SubModuleReplacementDescription(
                         suffix="embed_tokens",
                         target_module=embedding_cls,
-                        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                        kwargs=(
+                            {
+                                "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                                "fp8_communication": self.shard_config.fp8_communication,
+                            }
+                            if self.shard_config.enable_tensor_parallelism
+                            else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+                        ),
                     ),
                 ],
                 policy=policy,
@@ -303,6 +358,7 @@ class WhisperPolicy(Policy):
                     kwargs={
                         "gather_output": True,
                         "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
                     },
                 ),
                 policy=base_policy,