Mirror of https://github.com/hpcaitech/ColossalAI
parent 04386d9eff · commit eec77e5702
colossalai/inference/core/engine.py

@@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.batch_bucket import BatchBucket
-from colossalai.inference.config import InferenceConfig, InputMetaData, ModelInferenceConfig
+from colossalai.inference.config import InferenceConfig, InputMetaData
 from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.sampler import search_tokens
@@ -72,9 +72,8 @@ class InferenceEngine:

         self.verbose = verbose
         self.logger = get_dist_logger(__name__)
-        self.model_inference_config = inference_config.to_model_inference_config()

-        self.init_model(model_or_path, model_policy, self.model_inference_config)
+        self.init_model(model_or_path, model_policy)

         self.generation_config = inference_config.to_generation_config(self.model_config)
         self.generation_config_dict = self.generation_config.to_dict()
@@ -113,7 +112,6 @@ class InferenceEngine:
         self,
         model_or_path: Union[nn.Module, str],
         model_policy: Union[Policy, Type[Policy]] = None,
-        model_inference_config: ModelInferenceConfig = None,
     ):
         """
         Shard model or/and Load weight
@@ -178,7 +176,6 @@ class InferenceEngine:
         self.model = self._shardformer(
             model,
             model_policy,
-            model_inference_config,
             None,
             tp_group=tp_group,
         )
@@ -299,7 +296,6 @@ class InferenceEngine:
         self,
         model: nn.Module,
         model_policy: Policy,
-        model_inference_config: ModelInferenceConfig,
         stage_manager: PipelineStageManager = None,
         tp_group: ProcessGroupMesh = None,
     ) -> nn.Module:
colossalai/inference/modeling/backends/attention_backend.py

@@ -8,15 +8,16 @@ from colossalai.inference.utils import can_use_flash_attn2
 from colossalai.logging import get_dist_logger
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
     context_attention_unpadded,
     flash_decoding_attention,
 )

 logger = get_dist_logger(__name__)
 inference_ops = InferenceOpsLoader().load()


 @dataclass
 class AttentionMetaData:
     query_states: torch.Tensor
     key_states: torch.Tensor
     value_states: torch.Tensor
@@ -32,7 +33,8 @@ class AttentionMetaData:
     output_tensor: torch.Tensor = None
     use_spec_dec: bool = False
     use_alibi_attn: bool = False
+    use_cuda_kernel: bool = False


 class AttentionBackend(ABC):
     @abstractmethod
@@ -42,46 +44,30 @@ class AttentionBackend(ABC):
     @abstractmethod
     def decode(self, attn_metadatas: AttentionMetaData, **kwargs):
         raise NotImplementedError


 class CudaAttentionBackend(AttentionBackend):
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_spec_dec:
-            token_nums = kwargs.get('token_nums', -1)
-
-            attn_output = flash_attn_varlen_func(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                cu_seqlens_q=attn_metadata.cu_seqlens,
-                cu_seqlens_k=attn_metadata.cu_seqlens,
-                max_seqlen_k=attn_metadata.kv_seq_len,
-                max_seqlen_v=attn_metadata.kv_seq_len,
-                dropout_p=0.0,
-                softmax_scale=attn_metadata.sm_scale,
-                causal=True,
-            )
-            attn_output = attn_output.view(token_nums, -1)
-        else:
-            attn_output = context_attention_unpadded(
-                q=attn_metadata.query_states,
-                k=attn_metadata.key_states,
-                v=attn_metadata.value_states,
-                k_cache=attn_metadata.k_cache,
-                v_cache=attn_metadata.v_cache,
-                context_lengths=attn_metadata.sequence_lengths,
-                block_tables=attn_metadata.block_tables,
-                block_size=attn_metadata.block_size,
-                output=attn_metadata.output_tensor,
-                max_seq_len=attn_metadata.kv_seq_len,
-                sm_scale=attn_metadata.sm_scale,
-                use_new_kcache_layout=True,
-            )
+        token_nums = kwargs.get("token_nums", -1)
+
+        attn_output = flash_attn_varlen_func(
+            attn_metadata.query_states,
+            attn_metadata.key_states,
+            attn_metadata.value_states,
+            cu_seqlens_q=attn_metadata.cu_seqlens,
+            cu_seqlens_k=attn_metadata.cu_seqlens,
+            max_seqlen_q=attn_metadata.kv_seq_len,
+            max_seqlen_k=attn_metadata.kv_seq_len,
+            dropout_p=0.0,
+            softmax_scale=attn_metadata.sm_scale,
+            causal=True,
+            alibi_slopes=attn_metadata.alibi_slopes,
+        )
+        attn_output = attn_output.view(token_nums, -1)
         return attn_output

     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
+        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
         output_tensor = attn_metadata.output_tensor
         inference_ops.flash_decoding_attention(
             output_tensor,
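For context on the rewritten CUDA prefill above: flash-attn's varlen interface packs every sequence in the batch into a single (total_tokens, num_heads, head_dim) tensor delimited by a cumulative-length vector, which is why the backend passes the same cu_seqlens for queries and keys and then flattens the result with token_nums. Below is a minimal, self-contained sketch of that call pattern; it assumes flash-attn 2.x and a CUDA device are available, and the shapes and variable names are illustrative only, not part of the ColossalAI API.

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func  # same import the backend relies on

# Two sequences of lengths 3 and 5, packed into one unpadded batch of 8 tokens.
seq_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))  # [0, 3, 8]
total_tokens, num_heads, head_dim = int(seq_lens.sum()), 8, 64

q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seq_lens.max()),  # flash-attn expects max_seqlen_q / max_seqlen_k,
    max_seqlen_k=int(seq_lens.max()),  # the kwarg pair the new code above uses
    dropout_p=0.0,
    softmax_scale=head_dim**-0.5,
    causal=True,
)
print(out.shape)  # (total_tokens, num_heads, head_dim); the backend then views it as (token_nums, -1)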
@@ -99,8 +85,8 @@ class CudaAttentionBackend(AttentionBackend):
             attn_metadata.sm_scale,
         )
         return output_tensor


 class TritonAttentionBackend(AttentionBackend):
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
         return context_attention_unpadded(
@@ -113,13 +99,14 @@ class TritonAttentionBackend(AttentionBackend):
             block_tables=attn_metadata.block_tables,
             block_size=attn_metadata.block_size,
             output=attn_metadata.output_tensor,
+            alibi_slopes=attn_metadata.alibi_slopes,
             max_seq_len=attn_metadata.kv_seq_len,
             sm_scale=attn_metadata.sm_scale,
-            use_new_kcache_layout=False,
+            use_new_kcache_layout=attn_metadata.use_cuda_kernel,
         )

     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        fd_inter_tensor = kwargs.get('fd_inter_tensor', None)
+        fd_inter_tensor = kwargs.get("fd_inter_tensor", None)
         return flash_decoding_attention(
             q=attn_metadata.query_states,
             k_cache=attn_metadata.k_cache,
@@ -131,16 +118,25 @@ class TritonAttentionBackend(AttentionBackend):
             output=attn_metadata.output_tensor,
             mid_output=fd_inter_tensor.mid_output,
             mid_output_lse=fd_inter_tensor.mid_output_lse,
+            alibi_slopes=attn_metadata.alibi_slopes,
             sm_scale=attn_metadata.sm_scale,
-            kv_group_num=kwargs.get('num_key_value_groups', 0),
-            q_len=kwargs.get('q_len', 1),
+            kv_group_num=kwargs.get("num_key_value_groups", 1),
+            q_len=kwargs.get("q_len", 1),
         )


-def get_attention_backend(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionBackend:
-    use_flash_attn = can_use_flash_attn2(dtype)
+def get_attention_backend(
+    use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
+) -> AttentionBackend:
+    """
+    Get the attention backend based on the inference configurations. Only when:
+        1. using CUDA kernel (use_cuda_kernel=True)
+        2. can use flash attention (flash-attn installed and dtype is fp16 or bf16)
+        3. not using speculative decoding (currently cuda kernel not support speculative decoding)
+    will the CUDA-kernel-based backend be used for attention layer computations. Otherwise, use Triton attention backend.
+    """
+    use_flash_attn = can_use_flash_attn2(dtype)
     if use_cuda_kernel and use_flash_attn and not use_spec_dec:
         return CudaAttentionBackend()
     else:
         return TritonAttentionBackend()
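As a quick illustration of the selection rule documented in get_attention_backend above, the helper below is a hypothetical stand-in that only mirrors the predicate; the real can_use_flash_attn2 additionally checks that flash-attn is actually installed.

import torch

def choose_backend_name(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> str:
    # Stand-in for can_use_flash_attn2: flash-attn 2 only supports fp16/bf16.
    flash_attn_ok = dtype in (torch.float16, torch.bfloat16)
    if use_cuda_kernel and flash_attn_ok and not use_spec_dec:
        return "CudaAttentionBackend"
    return "TritonAttentionBackend"

assert choose_backend_name(False, True, torch.float16) == "CudaAttentionBackend"
assert choose_backend_name(True, True, torch.float16) == "TritonAttentionBackend"    # spec-dec falls back to Triton
assert choose_backend_name(False, True, torch.float32) == "TritonAttentionBackend"   # fp32 falls back to Triton
assert choose_backend_name(False, False, torch.float16) == "TritonAttentionBackend"  # CUDA kernels disabled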
colossalai/inference/modeling/backends/pre_attention_backend.py

@@ -13,60 +13,52 @@ from colossalai.kernel.triton import (

 logger = get_dist_logger(__name__)
 inference_ops = InferenceOpsLoader().load()


-class AttentionContext(ABC):
+class PreAttentionBackend(ABC):
     @abstractmethod
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
         raise NotImplementedError

     @abstractmethod
     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
         raise NotImplementedError


-class CudaAttentionContext(AttentionContext):
+class CudaPreAttentionBackend(PreAttentionBackend):
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
-        if not attn_metadata.use_spec_dec:
-            if not attn_metadata.use_alibi_attn:
-                inference_ops.rotary_embedding(
-                    attn_metadata.query_states,
-                    attn_metadata.key_states,
-                    kwargs.get('cos', None),
-                    kwargs.get('sin', None),
-                    kwargs.get('high_precision', False),
-                )
-            inference_ops.context_kv_cache_memcpy(
-                attn_metadata.key_states,
-                attn_metadata.value_states,
-                attn_metadata.k_cache,
-                attn_metadata.v_cache,
-                attn_metadata.sequence_lengths,
-                attn_metadata.cu_seqlens,
-                attn_metadata.block_tables,
-                attn_metadata.kv_seq_len,
-            )
-        else:
-            rotary_embedding(
-                attn_metadata.query_states,
-                attn_metadata.key_states,
-                kwargs.get('cos', None),
-                kwargs.get('sin', None),
-            )
+        if not attn_metadata.use_alibi_attn:
+            inference_ops.rotary_embedding(
+                attn_metadata.query_states,
+                attn_metadata.key_states,
+                kwargs.get("cos", None),
+                kwargs.get("sin", None),
+                kwargs.get("high_precision", False),
+            )
+        inference_ops.context_kv_cache_memcpy(
+            attn_metadata.key_states,
+            attn_metadata.value_states,
+            attn_metadata.k_cache,
+            attn_metadata.v_cache,
+            attn_metadata.sequence_lengths,
+            attn_metadata.cu_seqlens,
+            attn_metadata.block_tables,
+            attn_metadata.kv_seq_len,
+        )

     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
-        if attn_metadata.use_alibi_attn:
+        if not attn_metadata.use_alibi_attn:
             inference_ops.rotary_embedding_and_cache_copy(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
                 attn_metadata.value_states,
-                kwargs.get('cos', None),
-                kwargs.get('sin', None),
+                kwargs.get("cos", None),
+                kwargs.get("sin", None),
                 attn_metadata.k_cache,
                 attn_metadata.v_cache,
                 attn_metadata.sequence_lengths,
                 attn_metadata.block_tables,
-                attn_metadata.high_precision,
+                kwargs.get("high_precision", None),
             )
         else:
             inference_ops.decode_kv_cache_memcpy(
@@ -77,58 +69,63 @@ class CudaAttentionContext(AttentionContext):
                 attn_metadata.sequence_lengths,
                 attn_metadata.block_tables,
             )


-class TritonAttentionContext(AttentionContext):
+class TritonPreAttentionBackend(PreAttentionBackend):
     def prefill(self, attn_metadata: AttentionMetaData, **kwargs):
         if not attn_metadata.use_alibi_attn:
             rotary_embedding(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
-                kwargs.get('cos', None),
-                kwargs.get('sin', None),
+                kwargs.get("cos", None),
+                kwargs.get("sin", None),
             )

     def decode(self, attn_metadata: AttentionMetaData, **kwargs):
         if not attn_metadata.use_spec_dec and not attn_metadata.use_alibi_attn:
             decoding_fused_rotary_embedding(
                 attn_metadata.query_states,
                 attn_metadata.key_states,
                 attn_metadata.value_states,
-                kwargs.get('cos', None),
-                kwargs.get('sin', None),
+                kwargs.get("cos", None),
+                kwargs.get("sin", None),
                 attn_metadata.k_cache,
                 attn_metadata.v_cache,
                 attn_metadata.block_tables,
                 attn_metadata.sequence_lengths,
             )
-        else:
+        else:  # else if using speculative decoding
             if not attn_metadata.use_alibi_attn:
                 rotary_embedding(
                     attn_metadata.query_states,
                     attn_metadata.key_states,
-                    kwargs.get('cos', None),
-                    kwargs.get('sin', None),
+                    kwargs.get("cos", None),
+                    kwargs.get("sin", None),
                 )
             copy_k_to_blocked_cache(
                 attn_metadata.key_states,
                 attn_metadata.k_cache,
                 kv_lengths=attn_metadata.sequence_lengths,
                 block_tables=attn_metadata.block_tables,
-                n=kwargs.get('q_len', 1)
+                n=kwargs.get("q_len", 1),
             )
             copy_k_to_blocked_cache(
                 attn_metadata.value_states,
                 attn_metadata.v_cache,
                 kv_lengths=attn_metadata.sequence_lengths,
                 block_tables=attn_metadata.block_tables,
-                n=kwargs.get('q_len', 1)
+                n=kwargs.get("q_len", 1),
             )


-def get_attention_context(use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype) -> AttentionContext:
+def get_pre_attention_backend(
+    use_spec_dec: bool, use_cuda_kernel: bool, dtype: torch.dtype
+) -> PreAttentionBackend:
+    """
+    Get the backend for pre-attention computations, including potisional encoding like RoPE and KV cache initialization.
+    """
     use_flash_attn = can_use_flash_attn2(dtype)
     if use_cuda_kernel and use_flash_attn and not use_spec_dec:
-        return CudaAttentionContext()
+        return CudaPreAttentionBackend()
     else:
-        return TritonAttentionContext()
+        return TritonPreAttentionBackend()
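Taken together, the two factories split the attention layer into a pre-attention step (rotary embedding and KV-cache writes) and the attention kernel itself. The following is a condensed sketch of how a model forward is expected to drive them, based on the Llama and Baichuan hunks later in this diff; run_attention is a hypothetical helper written for illustration, not part of the library.

import torch

from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend


def run_attention(attn_metadata: AttentionMetaData, *, is_prompts: bool, is_verifier: bool,
                  use_cuda_kernel: bool, cos: torch.Tensor, sin: torch.Tensor,
                  high_precision: bool, token_nums: int, q_len: int, fd_inter_tensor):
    # Both factories use the same predicate, so the pre-attention step and the attention
    # kernel always agree on which kernels and KV-cache layout are in play.
    dtype = attn_metadata.query_states.dtype
    pre_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=dtype)
    backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=dtype)

    if is_prompts:  # prefill: apply RoPE, write the context KV cache, run unpadded attention
        pre_backend.prefill(attn_metadata, cos=cos, sin=sin, high_precision=high_precision)
        return backend.prefill(attn_metadata, token_nums=token_nums)
    # decode: fused RoPE + cache copy, then paged flash-decoding
    pre_backend.decode(attn_metadata, cos=cos, sin=sin, q_len=q_len)
    return backend.decode(attn_metadata, fd_inter_tensor=fd_inter_tensor, q_len=q_len)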
colossalai/inference/modeling/models/nopadding_baichuan.py

@@ -10,6 +10,8 @@ from torch.distributed import ProcessGroup
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.utils import get_alibi_slopes
 from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaMLP
+from colossalai.inference.modeling.backends.attention_backend import AttentionMetaData, get_attention_backend
+from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
     context_attention_unpadded,
@@ -23,28 +25,8 @@ from colossalai.logging import get_dist_logger
 from colossalai.shardformer.layer.parallel_module import ParallelModule
 from colossalai.tensor.d_tensor import Layout, distribute_tensor, is_distributed_tensor

-logger = get_dist_logger(__name__)
-
-try:
-    from flash_attn import flash_attn_varlen_func
-
-    use_flash_attn2 = True
-except ImportError:
-    use_flash_attn2 = False
-    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
-
-logger = get_dist_logger(__name__)
-
-try:
-    from flash_attn import flash_attn_varlen_func
-
-    use_flash_attn2 = True
-except ImportError:
-    use_flash_attn2 = False
-    logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.")
-
 inference_ops = InferenceOpsLoader().load()

 logger = get_dist_logger(__name__)

@@ -251,122 +233,54 @@ class NopadBaichuanAttention(ParallelModule):
         )

         block_size = k_cache.size(-2)

-        if is_prompts:
-            if not is_verifier and use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2:
-                # flash attn 2 currently only supports FP16/BF16.
-                if not self.use_alibi_attn:
-                    inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision)
-                inference_ops.context_kv_cache_memcpy(
-                    key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len
-                )
-                attn_output = flash_attn_varlen_func(
-                    query_states,
-                    key_states,
-                    value_states,
-                    cu_seqlens_q=cu_seqlens,
-                    cu_seqlens_k=cu_seqlens,
-                    max_seqlen_q=kv_seq_len,
-                    max_seqlen_k=kv_seq_len,
-                    dropout_p=0.0,
-                    softmax_scale=sm_scale,
-                    causal=True,
-                    alibi_slopes=self.alibi_slopes,
-                )
-                attn_output = attn_output.view(token_nums, -1)
-            else:
-                if not self.use_alibi_attn:
-                    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                attn_output = context_attention_unpadded(
-                    q=query_states,
-                    k=key_states,
-                    v=value_states,
-                    k_cache=k_cache,
-                    v_cache=v_cache,
-                    context_lengths=sequence_lengths,
-                    block_tables=block_tables,
-                    block_size=block_size,
-                    output=output_tensor,
-                    alibi_slopes=self.alibi_slopes,
-                    max_seq_len=kv_seq_len,
-                    sm_scale=sm_scale,
-                    use_new_kcache_layout=use_cuda_kernel,
-                )
-        else:
+        attn_metadata = AttentionMetaData(
+            query_states=query_states,
+            key_states=key_states,
+            value_states=value_states,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            block_tables=block_tables,
+            block_size=block_size,
+            kv_seq_len=kv_seq_len,
+            sequence_lengths=sequence_lengths,
+            sm_scale=sm_scale,
+            alibi_slopes=self.alibi_slopes,
+            cu_seqlens=cu_seqlens,
+            output_tensor=output_tensor,
+            use_spec_dec=is_verifier,
+            use_alibi_attn=self.use_alibi_attn,
+            use_cuda_kernel=use_cuda_kernel,
+        )
+
+        attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
+        pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
+
+        if is_prompts:  # prefilling stage
+            pre_attention_backend.prefill(
+                attn_metadata,
+                cos=cos_sin[0],
+                sin=cos_sin[1],
+                high_precision=high_precision,
+            )
+            attn_output = attention_backend.prefill(
+                attn_metadata,
+                token_nums=token_nums,
+            )
+        else:  # decoding stage
             q_len = tokens_to_verify + 1 if is_verifier else 1

-            if use_cuda_kernel:
-                if not self.use_alibi_attn:
-                    inference_ops.rotary_embedding_and_cache_copy(
-                        query_states,
-                        key_states,
-                        value_states,
-                        cos_sin[0],
-                        cos_sin[1],
-                        k_cache,
-                        v_cache,
-                        sequence_lengths,
-                        block_tables,
-                        high_precision,
-                    )
-                else:
-                    inference_ops.decode_kv_cache_memcpy(
-                        key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
-                    )
-                inference_ops.flash_decoding_attention(
-                    output_tensor,
-                    query_states,
-                    k_cache,
-                    v_cache,
-                    sequence_lengths,
-                    block_tables,
-                    block_size,
-                    kv_seq_len,
-                    fd_inter_tensor.mid_output,
-                    fd_inter_tensor.exp_sums,
-                    fd_inter_tensor.max_logits,
-                    self.alibi_slopes,
-                    sm_scale,
-                )
-                attn_output = output_tensor
-            else:
-                if not is_verifier and not self.use_alibi_attn:
-                    decoding_fused_rotary_embedding(
-                        query_states,
-                        key_states,
-                        value_states,
-                        cos_sin[0],
-                        cos_sin[1],
-                        k_cache,
-                        v_cache,
-                        block_tables,
-                        sequence_lengths,
-                    )
-                else:
-                    if not self.use_alibi_attn:
-                        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-                    copy_k_to_blocked_cache(
-                        key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
-                    )
-                    copy_k_to_blocked_cache(
-                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
-                    )
-
-                attn_output = flash_decoding_attention(
-                    q=query_states,
-                    k_cache=k_cache,
-                    v_cache=v_cache,
-                    kv_seq_len=sequence_lengths,
-                    block_tables=block_tables,
-                    block_size=block_size,
-                    max_seq_len_in_batch=kv_seq_len,
-                    output=output_tensor,
-                    mid_output=fd_inter_tensor.mid_output,
-                    mid_output_lse=fd_inter_tensor.mid_output_lse,
-                    alibi_slopes=self.alibi_slopes,
-                    sm_scale=sm_scale,
-                    q_len=q_len,
-                )
-
+            pre_attention_backend.decode(
+                attn_metadata,
+                cos=cos_sin[0],
+                sin=cos_sin[1],
+                q_len=q_len,
+            )
+            attn_output = attention_backend.decode(
+                attn_metadata,
+                fd_inter_tensor=fd_inter_tensor,
+                q_len=q_len,
+            )

         attn_output = attn_output.view(-1, self.hidden_size)
         attn_output = self.o_proj(attn_output)
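The Baichuan path above is the one that threads ALiBi through the new metadata (use_alibi_attn and alibi_slopes replace rotary embedding for that model). As background, here is a small reference sketch of how ALiBi slopes and biases are conventionally derived; reference_alibi_slopes is a hypothetical stand-in for the get_alibi_slopes utility imported at the top of the file and only covers power-of-two head counts.

import math
import torch


def reference_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric sequence starting at 2^(-8/num_heads); for 8 heads this gives
    # 0.5, 0.25, ..., 0.00390625.
    start = 2.0 ** (-(2.0 ** -(math.log2(num_heads) - 3)))
    return torch.tensor([start ** (i + 1) for i in range(num_heads)], dtype=torch.float32)


# ALiBi adds slope * (key_pos - query_pos) to each head's attention scores instead of
# rotating q/k, which is why the alibi path above skips rotary_embedding and only passes
# alibi_slopes down to the attention kernels.
slopes = reference_alibi_slopes(8)  # (num_heads,)
positions = torch.arange(5)
bias = slopes[:, None, None] * (positions[None, None, :] - positions[None, :, None])  # (heads, q, k)
print(bias.shape)  # torch.Size([8, 5, 5])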
colossalai/inference/modeling/models/nopadding_llama.py

@@ -19,7 +19,7 @@ from transformers.models.llama.modeling_llama import (
 from colossalai.inference.config import InputMetaData
 from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.backends.attention_backend import get_attention_backend, AttentionMetaData
-from colossalai.inference.modeling.backends.attention_context import get_attention_context
+from colossalai.inference.modeling.backends.pre_attention_backend import get_pre_attention_backend
 from colossalai.inference.utils import can_use_flash_attn2
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
 from colossalai.kernel.triton import (
@@ -121,7 +121,7 @@ def llama_model_forward(
         cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])

     elif use_cuda_kernel:
-        if inputmetadata.dtype != torch.float32 and can_use_flash_attn2():
+        if can_use_flash_attn2(inputmetadata.dtype):
             cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))

         hidden_dim = self._cos_cached.size(-1)
@@ -544,13 +544,14 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
             output_tensor=output_tensor,
             use_spec_dec=is_verifier,
             use_alibi_attn=False,
+            use_cuda_kernel=use_cuda_kernel,
         )

         attention_backend = get_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
-        attention_context = get_attention_context(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)
+        pre_attention_backend = get_pre_attention_backend(use_spec_dec=is_verifier, use_cuda_kernel=use_cuda_kernel, dtype=query_states.dtype)

         if is_prompts:  # prefilling stage
-            attention_context.prefill(
+            pre_attention_backend.prefill(
                 attn_metadata,
                 cos=cos_sin[0],
                 sin=cos_sin[1],
@@ -563,7 +564,7 @@ class NopadLlamaAttention(LlamaAttention, ParallelModule):
         else:  # decoding stage
             q_len = tokens_to_verify + 1 if is_verifier else 1

-            attention_context.decode(
+            pre_attention_backend.decode(
                 attn_metadata,
                 cos=cos_sin[0],
                 sin=cos_sin[1],