diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py
index f489f36..a5362c4 100644
--- a/internlm/model/modeling_llama.py
+++ b/internlm/model/modeling_llama.py
@@ -44,7 +44,7 @@ try:
     from flash_attn.modules.mlp import ParallelFusedMLP
     from flash_attn.ops.layer_norm import dropout_add_layer_norm
 except ImportError:
-    raise ImportError("Please check your flash_attn version >= 2.0.0.")
+    pass
 
 
 MODEL_TYPE = "LLAMA"
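
Replacing the `raise` with a bare `pass` swallows the ImportError silently; any later reference to `ParallelFusedMLP` or `dropout_add_layer_norm` would then fail with a `NameError` rather than a clear message about the missing dependency. A minimal sketch of a more explicit variant of the same optional-import pattern, using a hypothetical `flash_attn_available` flag that is not part of this patch:

```python
# Sketch only, not part of this patch: record availability instead of
# silently passing, so later code can branch on the (hypothetical) flag.
try:
    from flash_attn.modules.mlp import ParallelFusedMLP
    from flash_attn.ops.layer_norm import dropout_add_layer_norm

    flash_attn_available = True
except ImportError:
    # flash_attn is treated as optional here; mark it missing
    # rather than raising at import time.
    flash_attn_available = False
```

Code paths that actually need the fused kernels can then check `flash_attn_available` and either fall back or raise a targeted error at call time.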