diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 7f6608086..4c3aece9d 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -154,7 +154,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         zero_bucket_size_in_m: int = 12,
         cpu_offload: bool = False,
         communication_dtype: Optional[torch.dtype] = None,
-        overlap_communication: bool = True,
+        overlap_communication: bool = False,
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
diff --git a/colossalai/shardformer/policies/deepseek.py b/colossalai/shardformer/policies/deepseek.py
index d729a4ecc..605f69c4a 100644
--- a/colossalai/shardformer/policies/deepseek.py
+++ b/colossalai/shardformer/policies/deepseek.py
@@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
+from transformers.utils import is_flash_attn_greater_or_equal_2_10
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
@@ -206,6 +207,7 @@ class DeepseekPolicy(Policy):
             @staticmethod
             def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
                 original_attn.__class__ = flash_attn_cls
+                original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
                 return original_attn
 
             self.append_or_create_submodule_replacement(
diff --git a/tests/test_shardformer/test_model/test_shard_deepseek.py b/tests/test_shardformer/test_model/test_shard_deepseek.py
index 709963613..187c642da 100644
--- a/tests/test_shardformer/test_model/test_shard_deepseek.py
+++ b/tests/test_shardformer/test_model/test_shard_deepseek.py
@@ -60,6 +60,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         enable_sequence_parallelism=sp_size > 1,
         sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+        enable_flash_attention=sp_size > 1,
         overlap_communication=False,
         initial_scale=1,
         precision=precision,
@@ -161,7 +162,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
 
     # use checkpoint to load sharded zero model
-    model_dir = "./test_mixtral"
+    model_dir = "./test_deepseek"
     if rank == world_size - 1:
         os.makedirs(model_dir, exist_ok=True)
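
Note on the policy change: `_flash_attn_uses_top_left_mask` is the flag Hugging Face FlashAttention2 modules consult because flash-attn releases before 2.1 generate a top-left-aligned causal mask, while 2.1+ defaults to bottom-right alignment. A minimal sketch of the same swap pattern, assuming a hypothetical `DummyFlashAttn` stand-in for the dynamically loaded `DeepseekFlashAttention2`:

    import torch.nn as nn
    from transformers.utils import is_flash_attn_greater_or_equal_2_10


    class DummyFlashAttn(nn.Module):
        """Hypothetical stand-in for the dynamically loaded DeepseekFlashAttention2."""


    def swap_to_flash_attn(original_attn: nn.Module) -> nn.Module:
        # Re-class the eager attention module in place, as from_native_module does in the patch.
        original_attn.__class__ = DummyFlashAttn
        # flash-attn < 2.1 uses a top-left-aligned causal mask; 2.1+ uses bottom-right.
        original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
        return original_attn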