From 229db4bc1623e74f2007074db35cb198e6c3b51a Mon Sep 17 00:00:00 2001
From: hxwang <wang1570@e.ntu.edu.sg>
Date: Tue, 2 Jul 2024 09:02:21 +0000
Subject: [PATCH] [test] add mixtral for sequence classification

---
 .../shardformer/policies/auto_policy.py       |   3 +
 colossalai/shardformer/policies/mixtral.py    | 131 +++++++++++++++---
 2 files changed, 118 insertions(+), 16 deletions(-)

diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index ae9f3603c..1e0af031a 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -200,6 +200,9 @@ _POLICY_LIST = {
     "transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM": PolicyLocation(
         file_name="mixtral", class_name="MixtralForCausalLMPolicy"
     ),
+    "transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification": PolicyLocation(
+        file_name="mixtral", class_name="MixtralForSequenceClassificationPolicy"
+    ),
     # Qwen2
     "transformers.models.qwen2.modeling_qwen2.Qwen2Model": PolicyLocation(
         file_name="qwen2", class_name="Qwen2ModelPolicy"
diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index ad93e9469..e3cc48043 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Callable, Dict, List, Union
 
@@ -39,20 +40,81 @@ class MixtralPolicy(Policy):
             )
 
         if self.shard_config.enable_tensor_parallelism:
-            raise NotImplementedError("Tensor parallelism is not supported for Mixtral model now.")
-        if getattr(self.shard_config, "ep_group", None) is not None:
-            # expert parallel
-            self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="block_sparse_moe",
-                        target_module=EPMixtralSparseMoeBlock,
-                        kwargs={"ep_group": self.shard_config.ep_group},
-                    )
-                ],
-                policy=policy,
-                target_key=MixtralDecoderLayer,
-            )
+            raise NotImplementedError("Tensor parallelism is not supported for the Mixtral model yet.")
+            # assert (
+            #     self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of attention heads must be divisible by tensor parallel size."
+            # assert (
+            #     self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            # ), f"The number of key_value heads must be divisible by tensor parallel size."
+            # decoder_attribute_replacement = {
+            #     "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            #     "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+            #     // self.shard_config.tensor_parallel_size,
+            # }
+
+            # policy[MixtralDecoderLayer] = ModulePolicyDescription(
+            #     attribute_replacement=decoder_attribute_replacement,
+            #     sub_module_replacement=[
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.q_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.k_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.v_proj",
+            #             target_module=Linear1D_Col,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         SubModuleReplacementDescription(
+            #             suffix="self_attn.o_proj",
+            #             target_module=Linear1D_Row,
+            #             kwargs={
+            #                 'process_group': self.shard_config.tensor_parallel_process_group,
+            #             }
+            #         ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.gate_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.up_proj",
+            #         #     target_module=Linear1D_Col,
+            #         # ),
+            #         # SubModuleReplacementDescription(
+            #         #     suffix="mlp.down_proj",
+            #         #     target_module=Linear1D_Row,
+            #         # ),
+            #     ],
+            # )
+
+        if getattr(self.shard_config, "ep_group", None) is None:
+            raise ValueError("You must pass in ep_group via shard_config for expert parallel!")
+
+        # expert parallel
+        self.append_or_create_submodule_replacement(
+            description=[
+                SubModuleReplacementDescription(
+                    suffix="block_sparse_moe",
+                    target_module=EPMixtralSparseMoeBlock,
+                    kwargs={"ep_group": self.shard_config.ep_group},
+                )
+            ],
+            policy=policy,
+            target_key=MixtralDecoderLayer,
+        )
 
         # optimization configuration
         if self.shard_config.enable_fused_normalization:
@@ -81,7 +143,7 @@ class MixtralPolicy(Policy):
             )
 
         if self.shard_config.enable_flash_attention:
-            raise NotImplementedError("Flash attention has already been replaced in mixtral.")
+            warnings.warn("Flash attention is natively supported in transformers, will ignore the flag.")
 
         return policy
 
@@ -150,7 +212,7 @@ class MixtralModelPolicy(MixtralPolicy):
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in llama model"""
+        """No shared params in mixtral model"""
         return []
 
 
@@ -206,3 +268,40 @@ class MixtralForCausalLMPolicy(MixtralPolicy):
                     }
                 ]
         return []
+
+
+class MixtralForSequenceClassificationPolicy(MixtralPolicy):
+    def module_policy(self):
+        from transformers import MixtralForSequenceClassification
+
+        policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            # add a new item for sequence classification
+            new_item = {
+                MixtralForSequenceClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="score", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        )
+                    ]
+                )
+            }
+            policy.update(new_item)
+
+        if self.pipeline_stage_manager:
+            raise NotImplementedError("Pipeline parallelism is not supported for MixtralForSequenceClassification yet.")
+
+        return policy
+
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        stage_manager = self.pipeline_stage_manager
+        held_layers = super().get_held_layers()
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in llama for sequence classification model"""
+        return []
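
A minimal usage sketch for the new policy (not part of this patch): it assumes torch.distributed is already initialized (e.g. via colossalai.launch) and attaches ep_group directly to ShardConfig, mirroring the getattr-based check in MixtralPolicy; in practice the group is normally supplied by the MoE plugin, and the tiny config values are illustrative only.

import torch.distributed as dist
from transformers import MixtralConfig, MixtralForSequenceClassification

from colossalai.shardformer import ShardConfig, ShardFormer

# Tiny, illustrative config so the smoke test stays cheap.
config = MixtralConfig(
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    num_local_experts=4,
    num_experts_per_tok=2,
    num_labels=2,
)
model = MixtralForSequenceClassification(config)

shard_config = ShardConfig(enable_tensor_parallelism=False)
# MixtralPolicy reads the expert-parallel group via getattr(shard_config, "ep_group", None),
# so attach it here for the sketch.
shard_config.ep_group = dist.group.WORLD

sharded_model, _ = ShardFormer(shard_config=shard_config).optimize(model)
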