mirror of https://github.com/hpcaitech/ColossalAI
[deepseek] replace attn (a workaround for bug in transformers)
parent
6c39f0b144
commit
c3dc9b4dba
|
@ -1,4 +1,3 @@
|
||||||
import warnings
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, Dict, List, Union
|
from typing import Callable, Dict, List, Union
|
||||||
|
|
||||||
|
@ -195,11 +194,36 @@ class DeepseekPolicy(Policy):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.shard_config.enable_flash_attention:
|
if self.shard_config.enable_flash_attention:
|
||||||
warnings.warn(
|
# NOTE: there is a bug for toggling flash attention in AutoModel, which has to be used for deepseek right now
|
||||||
"Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
|
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
||||||
)
|
|
||||||
self.shard_config.enable_flash_attention = False
|
|
||||||
|
|
||||||
|
flash_attn_cls = get_class_from_dynamic_module(
|
||||||
|
"deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekFlashAttention2",
|
||||||
|
"deepseek-ai/deepseek-moe-16b-base",
|
||||||
|
)
|
||||||
|
|
||||||
|
class TargetFlashAttn:
|
||||||
|
def __init__(self):
|
||||||
|
raise RuntimeError("This class should not be instantiated")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
|
||||||
|
flash_attn_module = flash_attn_cls(original_attn.config, original_attn.layer_idx)
|
||||||
|
flash_attn_module.q_proj = original_attn.q_proj
|
||||||
|
flash_attn_module.k_proj = original_attn.k_proj
|
||||||
|
flash_attn_module.v_proj = original_attn.v_proj
|
||||||
|
flash_attn_module.o_proj = original_attn.o_proj
|
||||||
|
flash_attn_module.rotary_emb = original_attn.rotary_emb
|
||||||
|
return flash_attn_module
|
||||||
|
|
||||||
|
self.append_or_create_submodule_replacement(
|
||||||
|
description=SubModuleReplacementDescription(
|
||||||
|
suffix="self_attn",
|
||||||
|
target_module=TargetFlashAttn,
|
||||||
|
),
|
||||||
|
policy=policy,
|
||||||
|
target_key="DeepseekDecoderLayer",
|
||||||
|
)
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
|
|
|
@ -220,6 +220,7 @@ def check_deepseek(rank, world_size, port):
|
||||||
run_deepseek_test()
|
run_deepseek_test()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("redundant")
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
@clear_cache_before_run()
|
@clear_cache_before_run()
|
||||||
|
|
Loading…
Reference in New Issue