diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 067d3c981..73fe7df9b 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -121,9 +121,7 @@ class InferenceEngine:
                     casuallm = _supported_models[arch](hf_config)
                     if isinstance(casuallm, AutoModelForCausalLM):
                         # NOTE(caidi) It's necessary to add half() here, otherwise baichuan13B will overflow the memory.
-                        model = (
-                            AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half().cuda()
-                        )
+                        model = AutoModelForCausalLM.from_pretrained(model_or_path, trust_remote_code=True).half()
                     else:
                         model = _supported_models[arch](hf_config)
                 else:
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index ca8a0e696..e6b39ccfa 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -79,7 +79,6 @@ def baichuan_rmsnorm_forward(
         TypeError(
             "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
         )
-
     if use_cuda_kernel:
         if residual is not None:
             inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
@@ -137,6 +136,7 @@ class NopadBaichuanAttention(ParallelModule):
             self.alibi_slopes = get_alibi_slopes(config.num_attention_heads, device=attn_qproj_w.device)[
                 slopes_start : slopes_start + num_heads
             ].contiguous()
+            self.alibi_slopes = nn.Parameter(self.alibi_slopes)
 
     @staticmethod
     def from_native_module(
@@ -268,19 +268,13 @@ class NopadBaichuanAttention(ParallelModule):
         block_size = k_cache.size(-2)
 
         if is_prompts:
-            if (
-                not is_verifier
-                and use_cuda_kernel
-                and query_states.dtype != torch.float32
-                and use_flash_attn2
-                and not self.use_alibi_attn
-            ):
+            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
                 # flash attn 2 currently only supports FP16/BF16.
-                inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
+                if not self.use_alibi_attn:
+                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
                 inference_ops.context_kv_cache_memcpy(
                     key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
                 )
-
                 attn_output = flash_attn_varlen_func(
                     query_states,
                     key_states,
@@ -292,6 +286,7 @@ class NopadBaichuanAttention(ParallelModule):
                     dropout_p=0.0,
                     softmax_scale=sm_scale,
                     causal=True,
+                    alibi_slopes=self.alibi_slopes,
                 )
                 attn_output = attn_output.view(token_nums, -1)
             else:
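
For context on the ALiBi-related changes above, here is a minimal, self-contained sketch (not the ColossalAI implementation) of the idea: per-head ALiBi slopes are precomputed with the standard ALiBi schedule, sliced for the local tensor-parallel rank, and wrapped in an `nn.Parameter` so they follow the module across device/dtype moves; in the prefill path they can then be handed to flash-attn 2 via the `alibi_slopes` argument of `flash_attn_varlen_func` (available in recent flash-attn 2 releases) instead of skipping the flash-attn branch for ALiBi models. The helper and class names below (`get_alibi_slopes_sketch`, `AlibiAttentionStub`) are illustrative assumptions, not ColossalAI APIs.

```python
# Sketch only: standard ALiBi slope schedule plus the Parameter wrapping used in the diff above.
import math

import torch
import torch.nn as nn


def get_alibi_slopes_sketch(num_heads: int) -> torch.Tensor:
    """Per-head ALiBi slopes: a geometric series for the closest power-of-two head count,
    padded with slopes from the doubled schedule when num_heads is not a power of two."""
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_pow2) - 3)))
    slopes = torch.pow(torch.tensor(base), torch.arange(1, closest_pow2 + 1))
    if closest_pow2 != num_heads:
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_pow2) - 3)))
        num_extra = min(closest_pow2, num_heads - closest_pow2)
        extra = torch.pow(torch.tensor(extra_base), torch.arange(1, 1 + 2 * num_extra, 2))
        slopes = torch.cat([slopes, extra], dim=0)
    return slopes.float()


class AlibiAttentionStub(nn.Module):
    def __init__(self, num_heads: int, tp_rank: int = 0, tp_size: int = 1):
        super().__init__()
        heads_per_rank = num_heads // tp_size
        start = tp_rank * heads_per_rank
        slopes = get_alibi_slopes_sketch(num_heads)[start : start + heads_per_rank].contiguous()
        # Registering the slopes as a Parameter (as the diff does) lets them follow
        # .to(device) / .half() calls on the module instead of being pinned to the
        # device they were created on.
        self.alibi_slopes = nn.Parameter(slopes, requires_grad=False)


if __name__ == "__main__":
    attn = AlibiAttentionStub(num_heads=40)  # Baichuan-13B uses 40 attention heads
    print(attn.alibi_slopes.shape)  # torch.Size([40])
    # In the prefill path, these per-head slopes would be passed as
    # flash_attn_varlen_func(..., alibi_slopes=self.alibi_slopes).
```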