mirror of https://github.com/hpcaitech/ColossalAI
[misc] fix ci failure: change default value to false in moe plugin
parent 12d043ca00
commit 70793ce9ed
@@ -154,7 +154,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         zero_bucket_size_in_m: int = 12,
         cpu_offload: bool = False,
         communication_dtype: Optional[torch.dtype] = None,
-        overlap_communication: bool = True,
+        overlap_communication: bool = False,
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
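With this change the MoeHybridParallelPlugin no longer overlaps ZeRO gradient communication with computation by default; callers who want the previous behaviour must request it explicitly. A minimal sketch of opting back in, assuming the plugin's usual tp_size/pp_size/ep_size constructor arguments (not shown in this hunk) and placeholder values:

from colossalai.booster import Booster
from colossalai.booster.plugin import MoeHybridParallelPlugin

# Sketch only: overlap_communication now defaults to False, so ZeRO
# communication/computation overlap has to be turned on explicitly.
# tp_size/pp_size/ep_size below are placeholder values for illustration.
plugin = MoeHybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    ep_size=2,
    zero_stage=1,
    overlap_communication=True,  # opt back in to the old default
    precision="bf16",
)
booster = Booster(plugin=plugin)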
@@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
+from transformers.utils import is_flash_attn_greater_or_equal_2_10

 from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
 from colossalai.shardformer.layer.embedding import PaddingEmbedding, VocabParallelEmbedding1D
@@ -206,6 +207,7 @@ class DeepseekPolicy(Policy):
                 @staticmethod
                 def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
                     original_attn.__class__ = flash_attn_cls
+                    original_attn._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
                     return original_attn

             self.append_or_create_submodule_replacement(
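The added _flash_attn_uses_top_left_mask assignment follows the convention used by Hugging Face Transformers' FlashAttention-2 modules: flash-attn releases older than 2.1.0 produce a top-left-aligned causal mask, so the flag records whether that older alignment is in effect. A small sketch of the version gate, assuming only that transformers (and optionally flash-attn) is installed:

from transformers.utils import is_flash_attn_greater_or_equal_2_10

# True only when the installed flash-attn is older than 2.1.0, i.e. when
# the attention module must account for a top-left-aligned causal mask.
use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
print("uses top-left mask:", use_top_left_mask)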
@@ -60,6 +60,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
         zero_stage=stage,
         enable_sequence_parallelism=sp_size > 1,
         sequence_parallelism_mode="all_to_all" if sp_size > 1 else None,
+        enable_flash_attention=sp_size > 1,
         overlap_communication=False,
         initial_scale=1,
         precision=precision,
@@ -161,7 +162,7 @@ def run_zero_with_original_model(config: Tuple[int, ...]):
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)

     # use checkpoint to load sharded zero model
-    model_dir = "./test_mixtral"
+    model_dir = "./test_deepseek"
     if rank == world_size - 1:
         os.makedirs(model_dir, exist_ok=True)