
[misc] fix ci failure: change default value to false in moe plugin

Branch: colossalchat
Author: haze188 (4 months ago); committed by Hongxin Liu
Commit: 70793ce9ed
Changed files:
  1. colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (2 changed lines)
  2. colossalai/shardformer/policies/deepseek.py (2 changed lines)
  3. tests/test_shardformer/test_model/test_shard_deepseek.py (3 changed lines)

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (2 changed lines)

@@ -154,7 +154,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         zero_bucket_size_in_m: int = 12,
         cpu_offload: bool = False,
         communication_dtype: Optional[torch.dtype] = None,
-        overlap_communication: bool = True,
+        overlap_communication: bool = False,
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,

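For downstream users, the practical effect of the first hunk is that MoeHybridParallelPlugin no longer overlaps ZeRO communication with computation unless explicitly asked to. A minimal usage sketch follows; the parallelism sizes and precision below are illustrative placeholders, not values taken from this commit:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

# After this commit, overlap_communication defaults to False; pass it
# explicitly to opt back in to ZeRO communication/computation overlap.
plugin = MoeHybridParallelPlugin(
    tp_size=1,   # illustrative values, not from this commit
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    precision="bf16",
    overlap_communication=True,  # restore the previous default behavior
)
booster = Booster(plugin=plugin)
```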
colossalai/shardformer/policies/deepseek.py (2 changed lines)

@@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
+from transformers.utils import is_flash_attn_greater_or_equal_2_10
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D

@@ -206,6 +207,7 @@ class DeepseekPolicy(Policy):
                @staticmethod
                def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
                    original_attn.__class__ = flash_attn_cls
+                   original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
                    return original_attn
 
            self.append_or_create_submodule_replacement(

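For context on the second hunk: the policy swaps the eager attention module's __class__ to the flash-attention variant instead of constructing a new module, so attributes that the flash-attention __init__ would normally set must be patched in by hand. The sketch below isolates that pattern; to_flash_attention is a hypothetical helper name, and flash_attn_cls stands for whatever flash-attention class the policy resolves elsewhere:

```python
import torch.nn as nn
from transformers.utils import is_flash_attn_greater_or_equal_2_10


def to_flash_attention(original_attn: nn.Module, flash_attn_cls: type) -> nn.Module:
    # Re-type the existing eager attention module in place so its weights
    # and config are reused, rather than building a new module.
    original_attn.__class__ = flash_attn_cls
    # HF flash-attention modules consult this attribute when building the
    # causal mask: flash-attn < 2.1 produces a top-left-aligned mask, while
    # >= 2.1 uses bottom-right alignment, hence the negated version check.
    original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
    return original_attn
```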
tests/test_shardformer/test_model/test_shard_deepseek.py (3 changed lines)

@@ -60,6 +60,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         enable_sequence_parallelism=sp_size > 1,
         sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+        enable_flash_attention=sp_size > 1,
         overlap_communication=False,
         initial_scale=1,
         precision=precision,

@@ -161,7 +162,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
 
     # use checkpoint to load sharded zero model
-    model_dir = "./test_mixtral"
+    model_dir = "./test_deepseek"
     if rank == world_size - 1:
         os.makedirs(model_dir, exist_ok=True)

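The test changes mirror the policy change: flash attention is now switched on whenever the all_to_all sequence-parallel path is exercised, and the checkpoint directory name is corrected from the copied-over Mixtral test. A rough sketch of how the plugin arguments now fit together; build_plugin_kwargs is a hypothetical helper, and the real test spells these arguments out inline:

```python
def build_plugin_kwargs(stage: int, ep_size: int, pp_size: int, tp_size: int,
                        sp_size: int, precision: str) -> dict:
    # Arguments eventually passed to MoeHybridParallelPlugin in the test.
    return dict(
        pp_size=pp_size,
        tp_size=tp_size,
        ep_size=ep_size,
        zero_stage=stage,
        enable_sequence_parallelism=sp_size > 1,
        sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
        enable_flash_attention=sp_size > 1,  # new in this commit: flash attention tracks SP
        overlap_communication=False,
        initial_scale=1,
        precision=precision,
    )
```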