[inference] Add alibi to flash attn function (#5678)

* add alibi to flash attn function

* rm redundant modifications
yuehuayingxueluo 2024-04-30 19:35:05 +08:00 committed by GitHub
parent ef8e4ffe31
commit f79963199c
2 changed files with 6 additions and 13 deletions
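
Background on the change: ALiBi (Attention with Linear Biases) skips rotary position embeddings and instead adds a per-head bias to the pre-softmax attention scores that grows linearly with query-key distance, which is what the `alibi_slopes` argument wired through below controls. A minimal unfused sketch of that bias in plain PyTorch (function and variable names here are illustrative, not taken from the patch):

import torch

def add_alibi_bias(scores: torch.Tensor, slopes: torch.Tensor) -> torch.Tensor:
    """Add an ALiBi bias to raw attention scores.

    scores: (num_heads, q_len, kv_len) pre-softmax attention scores
    slopes: (num_heads,) positive per-head slopes
    """
    q_len, kv_len = scores.shape[-2], scores.shape[-1]
    # Align query positions to the end of the kv window (handles cached prefixes).
    q_pos = torch.arange(kv_len - q_len, kv_len, device=scores.device)
    k_pos = torch.arange(kv_len, device=scores.device)
    distance = k_pos[None, :] - q_pos[:, None]        # (q_len, kv_len), <= 0 on the causal part
    return scores + slopes[:, None, None] * distance  # farther-away keys get a larger penalty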


@@ -121,9 +121,7 @@ class InferenceEngine:
                 casuallm = _supported_models[arch](hf_config)
                 if isinstance(casuallm, AutoModelForCausalLM):
                     # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                    model = (
-                        AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
-                    )
+                    model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
                 else:
                     model = _supported_models[arch](hf_config)
             else:
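
The engine-side tweak above keeps the FP16 cast but drops the hard-coded `.cuda()`. In isolation, the loading pattern looks like the sketch below (the checkpoint name and the explicit device move are illustrative assumptions; the engine presumably handles placement elsewhere):

import torch
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any Baichuan-style causal LM with custom modeling code works similarly.
model_or_path = "baichuan-inc/Baichuan2-13B-Base"

# Cast to FP16 right after loading to roughly halve weight memory
# (per the NOTE above, baichuan13B otherwise overflows memory).
model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()

# Device placement is now left to the caller instead of a hard-coded .cuda().
model = model.to("cuda" if torch.cuda.is_available() else "cpu")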


@@ -79,7 +79,6 @@ def baichuan_rmsnorm_forward(
         TypeError(
             "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
         )
-
     if use_cuda_kernel:
         if residual is not None:
             inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
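
The context line above calls a fused residual-add + RMSNorm CUDA kernel. As a reference for what that kernel computes, here is an unfused PyTorch sketch with the same argument order; whether the real kernel writes its results in place is not visible in this diff, so this version simply returns them:

import torch

def add_rms_norm_reference(hidden_states, residual, weight, eps: float):
    """Unfused reference for a fused add + RMS layernorm.

    hidden_states, residual: (..., hidden_size); weight: (hidden_size,).
    Returns (normalized, new_residual): the residual is added first, then the
    sum is RMS-normalized and also kept as the residual for the next block.
    """
    new_residual = hidden_states + residual
    x = new_residual.float()                        # normalize in FP32 for stability
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    normalized = x * torch.rsqrt(variance + eps)
    return weight * normalized.to(hidden_states.dtype), new_residual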
@@ -137,6 +136,7 @@ class NopadBaichuanAttention(ParallelModule):
             self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
                 slopes_start : slopes_start + num_heads
             ].contiguous()
+            self.alibi_slopes = nn.Parameter(self.alibi_slopes)
 
     @staticmethod
     def from_native_module(
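
The hunk above slices this attention module's share of `get_alibi_slopes(...)` for tensor parallelism and, new in this commit, registers the slice as an `nn.Parameter`. The slope schedule itself is not part of the diff; the commonly used reference from the ALiBi paper assigns each head a geometric sequence, roughly as sketched below (this may differ in detail from ColossalAI's `get_alibi_slopes`):

import math
import torch

def reference_alibi_slopes(num_heads: int, device="cpu") -> torch.Tensor:
    """Per-head ALiBi slopes; e.g. 8 heads -> 1/2, 1/4, ..., 1/256."""
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_pow2) - 3)))
    slopes = [base ** (i + 1) for i in range(closest_pow2)]
    if closest_pow2 != num_heads:
        # Non-power-of-two head counts borrow every other slope from the next power of two.
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_pow2) - 3)))
        slopes += [extra_base ** i for i in range(1, 2 * (num_heads - closest_pow2), 2)]
    return torch.tensor(slopes, dtype=torch.float32, device=device)

One practical effect of the `nn.Parameter` wrapper is that the slopes now follow `module.to(...)` device moves and appear in `state_dict()`, unlike plain tensor attributes.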
@@ -268,19 +268,13 @@ class NopadBaichuanAttention(ParallelModule):
         block_size = k_cache.size(-2)
 
         if is_prompts:
-            if (
-                not is_verifier
-                and use_cuda_kernel
-                and query_states.dtype != torch.float32
-                and use_flash_attn2
-                and not self.use_alibi_attn
-            ):
+            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                 # flash attn 2 currently only supports FP16/BF16.
-                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                if not self.use_alibi_attn:
+                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                 inference_ops.context_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                 )
-
                 attn_output = flash_attn_varlen_func(
                     query_states,
                     key_states,
@@ -292,6 +286,7 @@ class NopadBaichuanAttention(ParallelModule):
                     dropout_p=0.0,
                     softmax_scale=sm_scale,
                     causal=True,
+                    alibi_slopes=self.alibi_slopes,
                 )
                 attn_output = attn_output.view(token_nums, -1)
             else:
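
With both hunks applied, the prefill path hands the per-head slopes directly to flash-attn, which applies the ALiBi bias inside the kernel, while the rotary embedding step is skipped for ALiBi models. A self-contained sketch of that call shape, assuming a recent flash-attn 2.x build with ALiBi support and a CUDA device; tensor sizes and slope values are illustrative:

import torch
from flash_attn import flash_attn_varlen_func

num_heads, head_dim = 8, 128
seq_lens = [5, 3]                          # two packed prompts
token_num = sum(seq_lens)

# Varlen layout: all tokens of all sequences packed along dim 0.
q = torch.randn(token_num, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence lengths delimit each sequence inside the packed batch.
cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
max_seqlen = max(seq_lens)

# Per-head ALiBi slopes; flash-attn expects these in float32.
alibi_slopes = torch.tensor([2.0 ** -(i + 1) for i in range(num_heads)],
                            dtype=torch.float32, device="cuda")

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    dropout_p=0.0,
    softmax_scale=head_dim ** -0.5,
    causal=True,
    alibi_slopes=alibi_slopes,             # the argument this commit adds
)
print(out.shape)                           # (token_num, num_heads, head_dim)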