mirror of https://github.com/hpcaitech/ColossalAI
[inference]Add alibi to flash attn function (#5678)

* add alibi to flash attn function
* rm redundant modifications

pull/5679/head
parent ef8e4ffe31
commit f79963199c
@@ -121,9 +121,7 @@ class InferenceEngine:
             casuallm = _supported_models[arch](hf_config)
             if isinstance(casuallm, AutoModelForCausalLM):
                 # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                model = (
-                    AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
-                )
+                model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
             else:
                 model = _supported_models[arch](hf_config)
         else:
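For reference, the new load path keeps only the FP16 cast and leaves device placement to the engine. Below is a minimal standalone sketch of the same call, assuming transformers is installed and using a placeholder checkpoint name in place of model_or_path; it is not the engine's own code.

    from transformers import AutoModelForCausalLM

    # Placeholder checkpoint; in the engine this is whatever model_or_path points to.
    checkpoint = "baichuan-inc/Baichuan-13B-Base"

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        trust_remote_code=True,  # Baichuan ships custom modeling code on the Hub
    ).half()                     # FP16 cast, per the note above about baichuan13B memory overflow
    # No explicit .cuda() here; moving the model to the GPU is assumed to happen elsewhere.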
@@ -79,7 +79,6 @@ def baichuan_rmsnorm_forward(
         TypeError(
             "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
         )
-
     if use_cuda_kernel:
         if residual is not None:
             inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
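As a reference for what the fused kernel is expected to compute, here is an unfused PyTorch sketch of a residual add followed by RMSNorm. This is an assumption about the semantics of inference_ops.fused_add_rms_layernorm (which works in place and may differ in accumulation dtype), not its implementation.

    import torch

    def add_rmsnorm_reference(hidden_states: torch.Tensor,
                              residual: torch.Tensor,
                              weight: torch.Tensor,
                              eps: float):
        # Residual add, then RMSNorm over the hidden dimension with a learned scale.
        summed = hidden_states + residual
        variance = summed.pow(2).mean(dim=-1, keepdim=True)
        normed = summed * torch.rsqrt(variance + eps) * weight
        return normed, summed  # the summed activations typically become the next residual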
@@ -137,6 +136,7 @@ class NopadBaichuanAttention(ParallelModule):
             self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
                 slopes_start : slopes_start + num_heads
             ].contiguous()
+            self.alibi_slopes = nn.Parameter(self.alibi_slopes)

     @staticmethod
     def from_native_module(
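The slice [slopes_start : slopes_start + num_heads] is this rank's share of the per-head ALiBi slopes, and wrapping it in nn.Parameter registers it on the module so it is carried along by module-level device and dtype moves. As a sketch, the standard slope schedule from the ALiBi paper looks like the following for power-of-two head counts; it is an assumption that get_alibi_slopes follows this formulation.

    import torch

    def alibi_slopes_pow2(num_attention_heads: int, device=None) -> torch.Tensor:
        # Head h gets slope base**(h + 1), a geometric sequence ending at 2**-8.
        base = 2 ** (-8.0 / num_attention_heads)
        return torch.tensor([base ** (h + 1) for h in range(num_attention_heads)], device=device)

    # Example: 4 heads -> tensor([0.2500, 0.0625, 0.0156, 0.0039])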
@@ -268,19 +268,13 @@ class NopadBaichuanAttention(ParallelModule):
         block_size = k_cache.size(-2)

         if is_prompts:
-            if (
-                not is_verifier
-                and use_cuda_kernel
-                and query_states.dtype != torch.float32
-                and use_flash_attn2
-                and not self.use_alibi_attn
-            ):
+            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                 # flash attn 2 currently only supports FP16/BF16.
-                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                if not self.use_alibi_attn:
+                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                 inference_ops.context_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                 )
-
                 attn_output = flash_attn_varlen_func(
                     query_states,
                     key_states,
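With use_alibi_attn set (the Baichuan-13B case), the rotary embedding is skipped: positions are encoded as a per-head linear distance penalty on the attention logits rather than by rotating q/k. A naive reference for that bias follows, under the assumption that flash attention applies an equivalent bias internally when alibi_slopes is passed.

    import torch

    def add_alibi_bias(scores: torch.Tensor, slopes: torch.Tensor) -> torch.Tensor:
        # scores: (num_heads, q_len, kv_len) attention logits; slopes: (num_heads,)
        q_len, kv_len = scores.shape[-2:]
        q_pos = torch.arange(kv_len - q_len, kv_len, device=scores.device)  # absolute query positions
        k_pos = torch.arange(kv_len, device=scores.device)
        # 0 on the diagonal, -(distance) to the left; positions right of the diagonal
        # are clamped to 0 here because the causal mask removes them anyway.
        distance = (q_pos[:, None] - k_pos[None, :]).clamp(min=0)
        return scores + slopes[:, None, None] * (-distance)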
@@ -292,6 +286,7 @@ class NopadBaichuanAttention(ParallelModule):
                     dropout_p=0.0,
                     softmax_scale=sm_scale,
                     causal=True,
+                    alibi_slopes=self.alibi_slopes,
                 )
                 attn_output = attn_output.view(token_nums, -1)
             else:
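For completeness, a standalone usage sketch of flash_attn_varlen_func with the alibi_slopes argument. It assumes flash-attn 2.4 or newer (where alibi_slopes was introduced) and uses made-up shapes; it is not the engine's calling code.

    import torch
    from flash_attn import flash_attn_varlen_func

    num_heads, head_dim = 8, 64
    # Two sequences of length 5 and 3 packed into one token dimension (varlen layout).
    cu_seqlens = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")
    total_tokens = 8

    q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    slopes = torch.rand(num_heads, dtype=torch.float32, device="cuda")  # per-head ALiBi slopes

    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=5,
        max_seqlen_k=5,
        dropout_p=0.0,
        causal=True,
        alibi_slopes=slopes,  # applied as a distance bias inside the kernel
    )
    # out: (total_tokens, num_heads, head_dim)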