mirror of https://github.com/hpcaitech/ColossalAI
[misc] fix ci failure: change default value to false in moe plugin
parent
12d043ca00
commit
70793ce9ed
|
@ -154,7 +154,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|||
zero_bucket_size_in_m: int = 12,
|
||||
cpu_offload: bool = False,
|
||||
communication_dtype: Optional[torch.dtype] = None,
|
||||
overlap_communication: bool = True,
|
||||
overlap_communication: bool = False,
|
||||
custom_policy: Policy = None,
|
||||
pp_style: str = "1f1b",
|
||||
num_model_chunks: int = 1,
|
||||
|
|
|
@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Union
|
|||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from transformers.utils import is_flash_attn_greater_or_equal_2_10
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
|
||||
from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
|
||||
|
@ -206,6 +207,7 @@ class DeepseekPolicy(Policy):
|
|||
@staticmethod
|
||||
def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
|
||||
original_attn.__class__ = flash_attn_cls
|
||||
original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
return original_attn
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
|
|
@ -60,6 +60,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
|
|||
zero_stage=stage,
|
||||
enable_sequence_parallelism=sp_size > 1,
|
||||
sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
|
||||
enable_flash_attention=sp_size > 1,
|
||||
overlap_communication=False,
|
||||
initial_scale=1,
|
||||
precision=precision,
|
||||
|
@ -161,7 +162,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
|
|||
assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
|
||||
|
||||
# use checkpoint to load sharded zero model
|
||||
model_dir = "./test_mixtral"
|
||||
model_dir = "./test_deepseek"
|
||||
if rank == world_size - 1:
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
|
|
Loading…
Reference in New Issue