
[misc] fix ci failure: change default value to false in moe plugin

colossalchat
haze188 4 months ago, committed by Hongxin Liu
commit 70793ce9ed
3 changed files:
  1. colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (2 changed lines)
  2. colossalai/shardformer/policies/deepseek.py (2 changed lines)
  3. tests/test_shardformer/test_model/test_shard_deepseek.py (3 changed lines)

colossalai/booster/plugin/moe_hybrid_parallel_plugin.py (2 changed lines)

@@ -154,7 +154,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         zero_bucket_size_in_m: int = 12,
         cpu_offload: bool = False,
         communication_dtype: Optional[torch.dtype] = None,
-        overlap_communication: bool = True,
+        overlap_communication: bool = False,
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
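For context, the change only flips the constructor default; callers that still want communication overlap can opt back in explicitly. A minimal usage sketch (the tp_size/pp_size/ep_size values and the surrounding Booster setup are illustrative assumptions, not part of this commit):

```python
# Hypothetical usage after this commit: overlap_communication now
# defaults to False, so overlap must be requested explicitly.
from colossalai.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin

plugin = MoeHybridParallelPlugin(
    tp_size=1,                   # assumed parallel sizes, for illustration only
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    overlap_communication=True,  # opt back in to the old overlapping behavior
)
booster = Booster(plugin=plugin)
```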

colossalai/shardformer/policies/deepseek.py (2 changed lines)

@@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
+from transformers.utils import is_flash_attn_greater_or_equal_2_10
 
 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
@@ -206,6 +207,7 @@ class DeepseekPolicy(Policy):
             @staticmethod
             def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
                 original_attn.__class__ = flash_attn_cls
+                original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
                 return original_attn
 
         self.append_or_create_submodule_replacement(
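The added line mirrors the convention in transformers that flash-attention builds older than 2.1 expect a top-left aligned causal mask. A standalone sketch of the same class-swapping pattern, with dummy attention classes standing in for the real ones (EagerAttn, FlashAttn, and to_flash_attn are illustrative names, not the policy's code):

```python
# Sketch of the class-swapping pattern used above: keep the module's
# weights and state, retarget its __class__ to the flash-attention variant,
# and record whether the installed flash-attn needs a top-left mask.
import torch.nn as nn
from transformers.utils import is_flash_attn_greater_or_equal_2_10


class EagerAttn(nn.Module):     # stand-in for the original attention class
    pass


class FlashAttn(EagerAttn):     # stand-in for the flash-attention subclass
    pass


def to_flash_attn(original_attn: nn.Module) -> nn.Module:
    original_attn.__class__ = FlashAttn
    # flash-attn < 2.1 generates a top-left aligned causal mask
    original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
    return original_attn
```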

tests/test_shardformer/test_model/test_shard_deepseek.py (3 changed lines)

@@ -60,6 +60,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         enable_sequence_parallelism=sp_size > 1,
         sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+        enable_flash_attention=sp_size > 1,
         overlap_communication=False,
         initial_scale=1,
         precision=precision,
@@ -161,7 +162,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
 
     # use checkpoint to load sharded zero model
-    model_dir = "./test_mixtral"
+    model_dir = "./test_deepseek"
     if rank == world_size - 1:
         os.makedirs(model_dir, exist_ok=True)
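The directory rename keeps this test's checkpoints from colliding with the mixtral test when both run on the same machine. A condensed sketch of the save-and-reload round trip around the changed line (the booster calls and the cleanup are assumptions for illustration, not the test's exact code):

```python
# Condensed sketch of the checkpoint round trip around the changed line.
# The booster save/load calls are illustrative assumptions, not the test's exact code.
import os
import shutil

import torch.distributed as dist


def checkpoint_round_trip(booster, model, rank: int, world_size: int) -> None:
    model_dir = "./test_deepseek"  # per-model directory, no clash with the mixtral test
    if rank == world_size - 1:
        os.makedirs(model_dir, exist_ok=True)
    dist.barrier()  # every rank waits until the directory exists

    booster.save_model(model, model_dir, shard=True)  # write a sharded checkpoint
    booster.load_model(model, model_dir)              # reload it into the same model

    dist.barrier()
    if rank == world_size - 1:
        shutil.rmtree(model_dir)  # clean up the test artifacts
```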
