[deepseek] replace attn (a workaround for bug in transformers)

colossalchat
hxwang 2024-07-23 12:56:58 +00:00 committed by Hongxin Liu
parent 6c39f0b144
commit c3dc9b4dba
2 changed files with 30 additions and 5 deletions

View File

@ -1,4 +1,3 @@
import warnings
from functools import partial from functools import partial
from typing import Callable, Dict, List, Union from typing import Callable, Dict, List, Union
@ -195,11 +194,36 @@ class DeepseekPolicy(Policy):
) )
if self.shard_config.enable_flash_attention: if self.shard_config.enable_flash_attention:
warnings.warn( # NOTE: there is a bug for toggling flash attention in AutoModel, which has to be used for deepseek right now
"Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False." from transformers.dynamic_module_utils import get_class_from_dynamic_module
)
self.shard_config.enable_flash_attention = False
flash_attn_cls = get_class_from_dynamic_module(
"deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekFlashAttention2",
"deepseek-ai/deepseek-moe-16b-base",
)
class TargetFlashAttn:
def __init__(self):
raise RuntimeError("This class should not be instantiated")
@staticmethod
def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
flash_attn_module = flash_attn_cls(original_attn.config, original_attn.layer_idx)
flash_attn_module.q_proj = original_attn.q_proj
flash_attn_module.k_proj = original_attn.k_proj
flash_attn_module.v_proj = original_attn.v_proj
flash_attn_module.o_proj = original_attn.o_proj
flash_attn_module.rotary_emb = original_attn.rotary_emb
return flash_attn_module
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="self_attn",
target_module=TargetFlashAttn,
),
policy=policy,
target_key="DeepseekDecoderLayer",
)
return policy return policy
def postprocess(self): def postprocess(self):

View File

@ -220,6 +220,7 @@ def check_deepseek(rank, world_size, port):
run_deepseek_test() run_deepseek_test()
@pytest.mark.skip("redundant")
@pytest.mark.dist @pytest.mark.dist
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
@clear_cache_before_run() @clear_cache_before_run()