mirror of https://github.com/hpcaitech/ColossalAI
[inference] Add alibi to flash attn function (#5678)
* add alibi to flash attn function
* rm redundant modifications

pull/5679/head
parent ef8e4ffe31
commit f79963199c
@@ -121,9 +121,7 @@ class InferenceEngine:
                 casuallm = _supported_models[arch](hf_config)
                 if isinstance(casuallm, AutoModelForCausalLM):
                     # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                    model = (
-                        AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
-                    )
+                    model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
                 else:
                     model = _supported_models[arch](hf_config)
             else:
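Note on the hunk above: per the in-diff NOTE, `.half()` keeps the baichuan13B weights in fp16 so loading does not overflow memory, and the eager `.cuda()` is dropped so device placement stays with the engine. A minimal sketch of the same loading pattern, assuming an arbitrary trust_remote_code causal-LM checkpoint (the model id below is only illustrative):

# Sketch only: checkpoint id is illustrative; any causal LM that needs
# trust_remote_code follows the same pattern as the hunk above.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-13B-Base",
    trust_remote_code=True,
).half()  # fp16 weights, as in the hunk; moving to GPU is left to the engine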
@@ -79,7 +79,6 @@ def baichuan_rmsnorm_forward(
         TypeError(
             "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
         )
 
     if use_cuda_kernel:
         if residual is not None:
             inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
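For context on `inference_ops.fused_add_rms_layernorm` in the hunk above: a fused add + RMSNorm kernel folds the residual addition into the normalization. Below is a plain-PyTorch reference of that math only, as a sketch; the in-place/output conventions of ColossalAI's actual kernel may differ, and the shapes/eps are illustrative.

import torch

def add_rmsnorm_reference(hidden_states: torch.Tensor, residual: torch.Tensor,
                          weight: torch.Tensor, eps: float):
    # Residual add happens before normalization; the summed tensor is usually
    # kept as the residual stream for the next layer.
    residual = hidden_states + residual
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    hidden_states = (residual.float() * torch.rsqrt(variance + eps)).to(residual.dtype) * weight
    return hidden_states, residual

h = torch.randn(4, 5120, dtype=torch.float16)   # 5120 = Baichuan-13B hidden size
r = torch.randn_like(h)
w = torch.ones(5120, dtype=torch.float16)
out, new_residual = add_rmsnorm_reference(h, r, w, eps=1e-6)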
@@ -137,6 +136,7 @@ class NopadBaichuanAttention(ParallelModule):
             self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
                 slopes_start : slopes_start + num_heads
             ].contiguous()
+            self.alibi_slopes = nn.Parameter(self.alibi_slopes)
 
     @staticmethod
     def from_native_module(
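The hunk above slices the per-rank slopes out of `get_alibi_slopes(config.num_attention_heads, ...)` and, new in this commit, wraps them in `nn.Parameter`. As a reference, the standard ALiBi recipe builds one slope per head from a geometric sequence; the sketch below follows the formula from the ALiBi paper and is not necessarily identical to ColossalAI's `get_alibi_slopes`. The tensor-parallel numbers at the end are illustrative (Baichuan-13B has 40 heads; with 2 ranks each keeps 20).

import math
import torch

def alibi_slopes_sketch(total_num_heads: int) -> torch.Tensor:
    # Geometric sequence starting at 2^(-8/n) for the closest power-of-two head
    # count; any remaining heads interleave a second sequence with base 2^(-4/n).
    closest_pow2 = 2 ** math.floor(math.log2(total_num_heads))
    base = 2.0 ** (-8.0 / closest_pow2)
    slopes = [base ** (i + 1) for i in range(closest_pow2)]
    if closest_pow2 != total_num_heads:
        extra_base = 2.0 ** (-4.0 / closest_pow2)
        slopes += [extra_base ** (2 * i + 1) for i in range(total_num_heads - closest_pow2)]
    return torch.tensor(slopes, dtype=torch.float32)

# Per-rank slice, mirroring `slopes_start : slopes_start + num_heads` above.
rank, num_heads = 1, 20        # illustrative: tensor-parallel world size 2
slopes_start = rank * num_heads
local_slopes = alibi_slopes_sketch(40)[slopes_start : slopes_start + num_heads].contiguous()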
@@ -268,19 +268,13 @@ class NopadBaichuanAttention(ParallelModule):
         block_size = k_cache.size(-2)
 
         if is_prompts:
-            if (
-                not is_verifier
-                and use_cuda_kernel
-                and query_states.dtype != torch.float32
-                and use_flash_attn2
-                and not self.use_alibi_attn
-            ):
+            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                 # flash attn 2 currently only supports FP16/BF16.
-                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                if not self.use_alibi_attn:
+                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                 inference_ops.context_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                 )
 
                 attn_output = flash_attn_varlen_func(
                     query_states,
                     key_states,
@@ -292,6 +286,7 @@ class NopadBaichuanAttention(ParallelModule):
                     dropout_p=0.0,
                     softmax_scale=sm_scale,
                     causal=True,
+                    alibi_slopes=self.alibi_slopes,
                 )
                 attn_output = attn_output.view(token_nums, -1)
             else:
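The new `alibi_slopes=self.alibi_slopes` argument hands the per-head slopes to FlashAttention-2 instead of skipping the flash-attention path for ALiBi models, which is why the `not self.use_alibi_attn` condition is dropped in the previous hunk. Below is a hedged standalone sketch of such a call; it assumes a CUDA device, a flash-attn build that supports the `alibi_slopes` argument, and illustrative shapes (20 heads of dim 128, roughly one Baichuan-13B tensor-parallel shard).

import torch
from flash_attn import flash_attn_varlen_func

num_heads, head_dim = 20, 128
seq_lens = [5, 9]                       # two prompts packed into one batch
cu_seqlens = torch.tensor([0, 5, 14], dtype=torch.int32, device="cuda")
total_tokens, max_seqlen = sum(seq_lens), max(seq_lens)

q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
alibi_slopes = torch.rand(num_heads, dtype=torch.float32, device="cuda")  # one fp32 slope per head

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    dropout_p=0.0,
    softmax_scale=head_dim**-0.5,
    causal=True,
    alibi_slopes=alibi_slopes,          # the argument this commit adds
)
print(out.shape)                        # (total_tokens, num_heads, head_dim)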