Browse Source

[deepseek] replace attn (a workaround for bug in transformers)

colossalchat
hxwang 4 months ago committed by Hongxin Liu
parent
commit
c3dc9b4dba
  1. 32
      colossalai/shardformer/policies/deepseek.py
  2. 1
      tests/test_shardformer/test_model/test_shard_deepseek_ghz.py

32
colossalai/shardformer/policies/deepseek.py

@@ -1,4 +1,3 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union
@@ -195,11 +194,36 @@ class DeepseekPolicy(Policy):
)
if self.shard_config.enable_flash_attention:
warnings.warn(
"Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
# NOTE: there is a bug for toggling flash attention in AutoModel, which has to be used for deepseek right now
from transformers.dynamic_module_utils import get_class_from_dynamic_module
flash_attn_cls = get_class_from_dynamic_module(
"deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekFlashAttention2",
"deepseek-ai/deepseek-moe-16b-base",
)
self.shard_config.enable_flash_attention = False
class TargetFlashAttn:
    """Replacement target used by the shardformer policy machinery.

    Never instantiated directly: the policy framework only calls the
    ``from_native_module`` hook, which swaps a native Deepseek attention
    module for the dynamically loaded ``DeepseekFlashAttention2``
    (``flash_attn_cls``) while reusing the original projection weights.
    """

    def __init__(self):
        raise RuntimeError("This class should not be instantiated")

    @staticmethod
    def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
        # Build a fresh flash-attention module with the same config/layer index,
        # then graft the already-initialized submodules onto it so no weights
        # are lost in the swap.
        replacement = flash_attn_cls(original_attn.config, original_attn.layer_idx)
        for name in ("q_proj", "k_proj", "v_proj", "o_proj", "rotary_emb"):
            setattr(replacement, name, getattr(original_attn, name))
        return replacement
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="self_attn",
target_module=TargetFlashAttn,
),
policy=policy,
target_key="DeepseekDecoderLayer",
)
return policy
def postprocess(self):

1
tests/test_shardformer/test_model/test_shard_deepseek_ghz.py

@@ -220,6 +220,7 @@ def check_deepseek(rank, world_size, port):
run_deepseek_test()
@pytest.mark.skip("redundant")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()

Loading…
Cancel
Save