mirror of https://github.com/hpcaitech/ColossalAI
[format] applied code formatting on changed files in pull request 4908 (#4918)
Co-authored-by: github-actions <github-actions@github.com>
pull/4934/head
parent 4f68b3f10c
commit a41cf88e9b
@@ -6,25 +6,20 @@ from typing import Optional, Tuple
 import torch
 import torch.nn.functional as F
+from einops import rearrange
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_kvpacked_func
+from flash_attn.ops.rms_norm import rms_norm
 from transformers.models.llama.modeling_llama import (
-    LlamaRMSNorm,
     LlamaAttention,
-    LlamaModel,
     LlamaForCausalLM,
+    LlamaModel,
+    LlamaRMSNorm,
     apply_rotary_pos_emb,
     repeat_kv,
 )

 from colossalai.logging import get_dist_logger
-from einops import rearrange
-
-from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import (
-    flash_attn_func,
-    flash_attn_varlen_kvpacked_func,
-)
-from flash_attn.ops.rms_norm import rms_norm
-

 logger = get_dist_logger()
@@ -65,7 +60,7 @@ def attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    **kwargs
+    **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """
     Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
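For context, the `attention_forward` touched by the hunk above re-defines `LlamaAttention`'s forward pass with flash-attention, per its docstring. Below is a deliberately simplified sketch of that pattern, not the code from this commit: it assumes the flash-attn 2.x `flash_attn_func(q, k, v, dropout_p=..., causal=...)` interface with inputs shaped (batch, seq_len, num_heads, head_dim), and it omits the rotary embeddings, KV cache, and padding-mask handling that the real `attention_forward` deals with. The names `simple_flash_forward` and `apply_flash_attention` are illustrative only and do not appear in the diff.

from types import MethodType
from typing import Optional, Tuple

import torch
from flash_attn.flash_attn_interface import flash_attn_func
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaForCausalLM


def simple_flash_forward(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    # Project hidden states and split heads into (bsz, q_len, n_heads, head_dim),
    # the layout flash_attn_func expects. RoPE / KV cache are skipped in this sketch.
    bsz, q_len, _ = hidden_states.size()
    q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
    k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
    v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
    # Fused causal attention kernel; requires CUDA tensors in fp16/bf16.
    attn_output = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    # Match LlamaAttention's (output, attn_weights, past_key_value) return shape.
    return self.o_proj(attn_output), None, None


def apply_flash_attention(model: LlamaForCausalLM) -> None:
    # Rebind forward on every attention module of an already-loaded model.
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            module.forward = MethodType(simple_flash_forward, module)

With this rebinding pattern the patch is applied after `LlamaForCausalLM.from_pretrained(...)` and before training or inference; the file's own `attention_forward` presumably gets attached the same way while keeping full fidelity with the stock attention.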