mirror of https://github.com/hpcaitech/ColossalAI
[deepseek] replace attn (a workaround for bug in transformers)
parent
6c39f0b144
commit
c3dc9b4dba
|
@ -1,4 +1,3 @@
|
|||
import warnings
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
|
@ -195,11 +194,36 @@ class DeepseekPolicy(Policy):
|
|||
)
|
||||
|
||||
if self.shard_config.enable_flash_attention:
|
||||
warnings.warn(
|
||||
"Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
|
||||
)
|
||||
self.shard_config.enable_flash_attention = False
|
||||
# NOTE: there is a bug for toggling flash attention in AutoModel, which has to be used for deepseek right now
|
||||
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
||||
|
||||
flash_attn_cls = get_class_from_dynamic_module(
|
||||
"deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekFlashAttention2",
|
||||
"deepseek-ai/deepseek-moe-16b-base",
|
||||
)
|
||||
|
||||
class TargetFlashAttn:
|
||||
def __init__(self):
|
||||
raise RuntimeError("This class should not be instantiated")
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
|
||||
flash_attn_module = flash_attn_cls(original_attn.config, original_attn.layer_idx)
|
||||
flash_attn_module.q_proj = original_attn.q_proj
|
||||
flash_attn_module.k_proj = original_attn.k_proj
|
||||
flash_attn_module.v_proj = original_attn.v_proj
|
||||
flash_attn_module.o_proj = original_attn.o_proj
|
||||
flash_attn_module.rotary_emb = original_attn.rotary_emb
|
||||
return flash_attn_module
|
||||
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="self_attn",
|
||||
target_module=TargetFlashAttn,
|
||||
),
|
||||
policy=policy,
|
||||
target_key="DeepseekDecoderLayer",
|
||||
)
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
|
|
@ -220,6 +220,7 @@ def check_deepseek(rank, world_size, port):
|
|||
run_deepseek_test()
|
||||
|
||||
|
||||
@pytest.mark.skip("redundant")
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
|
|
Loading…
Reference in New Issue