diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 287a0e2..49578d7 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -6,6 +6,7 @@ from typing import Any, Optional, Tuple

 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from einops import rearrange
 from flash_attn.modules.mha import (
     CrossAttention,