[inference]Add alibi to flash attn function (#5678)

* add alibi to flash attn function

* rm redundant modifications
yuehuayingxueluo 2024-04-30 19:35:05 +08:00 committed by GitHub
parent ef8e4ffe31
commit f79963199c
2 changed files with 6 additions and 13 deletions

@@ -121,9 +121,7 @@ class InferenceEngine:
                 casuallm = _supported_models[arch](hf_config)
                 if isinstance(casuallm, AutoModelForCausalLM):
                     # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                    model = (
-                        AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
-                    )
+                    model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
                 else:
                     model = _supported_models[arch](hf_config)
             else:

@@ -79,7 +79,6 @@ def baichuan_rmsnorm_forward(
         TypeError(
             "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
         )
     if use_cuda_kernel:
         if residual is not None:
             inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
@@ -137,6 +136,7 @@ class NopadBaichuanAttention(ParallelModule):
             self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
                 slopes_start : slopes_start + num_heads
             ].contiguous()
+            self.alibi_slopes = nn.Parameter(self.alibi_slopes)

     @staticmethod
     def from_native_module(
@@ -268,19 +268,13 @@ class NopadBaichuanAttention(ParallelModule):
         block_size = k_cache.size(-2)

         if is_prompts:
-            if (
-                not is_verifier
-                and use_cuda_kernel
-                and query_states.dtype != torch.float32
-                and use_flash_attn2
-                and not self.use_alibi_attn
-            ):
+            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                 # flash attn 2 currently only supports FP16/BF16.
-                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                if not self.use_alibi_attn:
+                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                 inference_ops.context_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                 )
                 attn_output = flash_attn_varlen_func(
                     query_states,
                     key_states,
@@ -292,6 +286,7 @@ class NopadBaichuanAttention(ParallelModule):
                     dropout_p=0.0,
                     softmax_scale=sm_scale,
                     causal=True,
+                    alibi_slopes=self.alibi_slopes,
                 )
                 attn_output = attn_output.view(token_nums, -1)
             else:
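
For context on what the added alibi_slopes= argument does: flash-attn 2's flash_attn_varlen_func accepts a per-head slope tensor (fp32, shape (num_heads,) or (batch, num_heads)) and applies the ALiBi distance bias inside the fused kernel. That is what lets the ALiBi-based Baichuan 13B share the FP16/BF16 flash-attention prefill path with the RoPE models, skipping only the rotary-embedding step when self.use_alibi_attn is set. Below is a minimal sketch of the standard ALiBi slope computation; the helper name and signature are illustrative and not copied from ColossalAI's get_alibi_slopes.

import math

import torch


def alibi_slopes(num_heads: int, device: str = "cuda") -> torch.Tensor:
    """Standard ALiBi head slopes: 2^(-8/n), 2^(-16/n), ... for n heads.

    For non-power-of-two head counts, the remaining slopes are taken from the
    odd positions of the next power of two, following the original ALiBi paper.
    """
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / closest_pow2)
    slopes = [base ** (i + 1) for i in range(closest_pow2)]
    if closest_pow2 != num_heads:
        # e.g. Baichuan 13B has 40 heads: 32 regular slopes plus 8 interpolated ones.
        extra_base = 2.0 ** (-4.0 / closest_pow2)
        num_extra = num_heads - closest_pow2
        slopes += [extra_base ** (2 * i + 1) for i in range(num_extra)]
    # flash-attn expects fp32 slopes of shape (num_heads,) or (batch, num_heads).
    return torch.tensor(slopes, dtype=torch.float32, device=device)

With slopes of that form stored on the module, the prefill call reduces to passing alibi_slopes=self.alibi_slopes to flash_attn_varlen_func, as the last hunk shows.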