@@ -1,4 +1,3 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union
@@ -195,11 +194,36 @@ class DeepseekPolicy(Policy):
            )

        if self.shard_config.enable_flash_attention:
            warnings.warn(
                "Flash attention has already been replaced in deepseek, and now set enable_flash_attention = False."
            )
            # NOTE: there is a bug for toggling flash attention in AutoModel, which has to be used for deepseek right now
            from transformers.dynamic_module_utils import get_class_from_dynamic_module
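
            # get_class_from_dynamic_module resolves a "<repo_id>--<module_file>.<ClassName>"
            # reference against the checkpoint's remote code, so the class below comes from
            # the model repo's own modeling_deepseek.py rather than from transformers itself.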
            flash_attn_cls = get_class_from_dynamic_module(
                "deepseek-ai/deepseek-moe-16b-base--modeling_deepseek.DeepseekFlashAttention2",
                "deepseek-ai/deepseek-moe-16b-base",
            )
            self.shard_config.enable_flash_attention = False
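
            # Stand-in for a shardformer target module: it is never instantiated directly,
            # and only provides from_native_module for the replacement registered below.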
            class TargetFlashAttn:
                def __init__(self):
                    raise RuntimeError("This class should not be instantiated")

                @staticmethod
                def from_native_module(original_attn: nn.Module, *args, **kwargs) -> nn.Module:
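                    # Build a fresh flash-attention module, then re-bind (rather than copy)
                    # the original projection and rotary-embedding submodules, so no
                    # parameters are duplicated in the swap.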
                    flash_attn_module = flash_attn_cls(original_attn.config, original_attn.layer_idx)
                    flash_attn_module.q_proj = original_attn.q_proj
                    flash_attn_module.k_proj = original_attn.k_proj
                    flash_attn_module.v_proj = original_attn.v_proj
                    flash_attn_module.o_proj = original_attn.o_proj
                    flash_attn_module.rotary_emb = original_attn.rotary_emb
                    return flash_attn_module
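
            # Register the swap: for every DeepseekDecoderLayer, shardformer replaces
            # the "self_attn" submodule via TargetFlashAttn.from_native_module.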
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="self_attn",
                    target_module=TargetFlashAttn,
                ),
                policy=policy,
                target_key="DeepseekDecoderLayer",
            )

        return policy

    def postprocess(self):