mirror of https://github.com/hpcaitech/ColossalAI
[Inference]Add Nopadding Llama Modeling (#5327)
* add nopadding llama modeling
* add nopadding_llama.py
* rm unused codes
* fix bugs in test_xine_copy.py
* fix code style

pull/5332/head
parent c7c104cb7c
commit e8f0642f28
@@ -32,6 +32,7 @@ class InferenceConfig:
            During generation, the beam width provided as sampling parameter should be less than or equivalent to this value.
        prefill_ratio (Optional[float]): A controlling ratio for prefill and decoding in the running list; we will do a step of prefill
            when the actual value exceeds this ratio.
        pad_input: Whether to pad all inputs to the max length.
        quant_mode (Optional[str]): Quantization mode.
        revision (Optional[str]): The specific version (a branch name, a commit id, or a tag name) of the model to use.
    """
@@ -49,6 +50,7 @@ class InferenceConfig:
    beam_width: int = 1
    # the ratio of prefill sequences to decoding sequences; we do a prefill step once the actual value exceeds this ratio
    prefill_ratio: Optional[float] = 1.2
    pad_input: bool = False
    quant_mode: Optional[str] = None
    revision: Optional[str] = None
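For orientation, a minimal usage sketch of the new `pad_input` flag. The import path and keyword arguments below are assumptions based on the fields shown in this hunk, not a verified API surface.

```python
# Hypothetical usage sketch; assumes InferenceConfig is the dataclass shown in
# this hunk and that it lives at colossalai.inference.config.
from colossalai.inference.config import InferenceConfig

config = InferenceConfig(
    prefill_ratio=1.2,
    pad_input=False,  # False routes the engine to the new no-padding llama modeling path
)
```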
@@ -57,7 +57,11 @@ class InferenceEngine:
        model.to(self.dtype)

        if model_policy is None:
            model_policy = model_policy_map[self.model_config.model_type]()
            if self.inference_config.pad_input:
                model_type = "padding_" + self.model_config.model_type
            else:
                model_type = "nopadding_" + self.model_config.model_type
            model_policy = model_policy_map[model_type]()

        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
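As context for the hunk above: the engine simply prefixes the Hugging Face model type with `padding_` or `nopadding_` and looks the result up in the policy map. A standalone sketch of that dispatch, with placeholder policy classes rather than the real ColossalAI ones:

```python
# Standalone sketch of the padding_/nopadding_ policy dispatch; the policy
# classes here are placeholders, not the real ColossalAI implementations.
class PaddingLlamaPolicy: ...
class NoPaddingLlamaPolicy: ...

model_policy_map = {
    "padding_llama": PaddingLlamaPolicy,
    "nopadding_llama": NoPaddingLlamaPolicy,
}

def select_policy(model_type: str, pad_input: bool):
    key = ("padding_" if pad_input else "nopadding_") + model_type
    return model_policy_map[key]()

policy = select_policy("llama", pad_input=False)  # -> NoPaddingLlamaPolicy instance
```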
@@ -168,7 +172,9 @@ class InferenceEngine:

        if prompts_token_ids is None:
            assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
                "input_ids"
            ]

        if isinstance(prompts_token_ids, list):
            pass
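The effect of forwarding `pad_input` to the tokenizer is that `padding=False` returns ragged per-prompt id lists instead of equal-length rows. A small sketch with a Hugging Face tokenizer (the checkpoint name is only an example, and the pad token is set manually because llama tokenizers ship without one):

```python
# Sketch of the padding switch above; checkpoint name is illustrative only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.eos_token  # needed for padding=True

prompts = ["Hello", "A noticeably longer prompt"]
padded = tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]   # equal-length rows
ragged = tokenizer.batch_encode_plus(prompts, padding=False)["input_ids"]  # per-prompt lengths
```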
@@ -237,7 +243,9 @@ class InferenceEngine:
            self.v_cache,
        )

        logits = logits[:, -1, :]
        if self.inference_config.pad_input:
            logits = logits[:, -1, :]

        self.request_handler.search_tokens(self.generation_config, logits)
        finished_sequences = self.request_handler.update()
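The conditional slice exists because the padded model still returns logits for every position, while the no-padding model already returns one row per sequence. A toy shape check of the two cases:

```python
# Toy shapes illustrating why [:, -1, :] is only needed on the padded path.
import torch

batch, seq_len, vocab = 2, 8, 32
padded_logits = torch.randn(batch, seq_len, vocab)  # padded path: one row per position
last_logits = padded_logits[:, -1, :]               # -> (batch, vocab)

nopad_logits = torch.randn(batch, vocab)            # no-padding path: already last-token only
assert last_logits.shape == nopad_logits.shape
```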
@@ -0,0 +1,221 @@
# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
from typing import List, Optional, Tuple

import torch
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
)

from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.struct import BatchInfo
from colossalai.kernel.triton import (
    context_attention_unpadded,
    copy_kv_to_blocked_cache,
    flash_decoding_attention,
    get_xine_cache,
    rotary_embedding,
)
from colossalai.logging import get_dist_logger

from flash_attn.bert_padding import index_first_axis, pad_input  # noqa

logger = get_dist_logger(__name__)

try:
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.")


@torch.no_grad()
def llama_causal_lm_forward(
    self: LlamaForCausalLM,
    batch: BatchInfo = None,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
):
    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    hidden_states = llama_model_forward(
        self.model,
        batch=batch,
        k_caches=k_caches,
        v_caches=v_caches,
    )
    logits = torch.mm(hidden_states, self.lm_head.weight.transpose(0, 1))
    return logits


@torch.no_grad()
def llama_model_forward(
    self: LlamaModel,
    batch: BatchInfo = None,
    k_caches: List[torch.Tensor] = None,
    v_caches: List[torch.Tensor] = None,
):
    input_ids = batch.get_1D_inputs()
    block_tables = batch.get_block_table_tensor()

    sequence_lengths = batch.get_sequence_lengths()
    batch_size = len(sequence_lengths)
    kv_seq_len = sequence_lengths.max().item()

    hidden_states = self.embed_tokens(input_ids)

    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)

    if batch.is_prompts:
        output_tensor = torch.zeros(
            (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
        )
    else:
        output_tensor = torch.zeros(
            (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
        )
    sm_scale = 1.0 / (batch.head_dim**0.5)

    for layer_id, decoder_layer in enumerate(self.layers):
        hidden_states = decoder_layer(
            hidden_states,
            block_tables=block_tables,
            k_cache=k_caches[layer_id],
            v_cache=v_caches[layer_id],
            is_prompts=batch.is_prompts,
            sequence_lengths=sequence_lengths,
            kv_seq_len=kv_seq_len,
            cos_sin=cos_sin,
            fd_inter_tensor=batch.fd_inter_tensor,
            output_tensor=output_tensor,
            sm_scale=sm_scale,
        )

    if batch.is_prompts:
        last_token_indexs = sequence_lengths.cumsum(dim=-1)
        hidden_states = hidden_states[last_token_indexs - 1].contiguous()
    hidden_states = self.norm(hidden_states)

    return hidden_states


@torch.no_grad()
def llama_decoder_layer_forward(
    self: LlamaDecoderLayer,
    hidden_states: torch.Tensor,
    block_tables: torch.Tensor = None,
    k_cache: torch.Tensor = None,
    v_cache: torch.Tensor = None,
    is_prompts: bool = True,
    sequence_lengths: torch.Tensor = None,
    kv_seq_len: int = 0,
    cos_sin: Tuple[torch.Tensor] = None,
    fd_inter_tensor: FDIntermTensors = None,
    output_tensor: torch.Tensor = None,
    sm_scale: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)
    # Self Attention
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
        block_tables=block_tables,
        k_cache=k_cache,
        v_cache=v_cache,
        is_prompts=is_prompts,
        sequence_lengths=sequence_lengths,
        kv_seq_len=kv_seq_len,
        cos_sin=cos_sin,
        fd_inter_tensor=fd_inter_tensor,
        output_tensor=output_tensor,
        sm_scale=sm_scale,
    )

    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    return hidden_states


# Replace transformers.models.llama.modeling_llama.LlamaAttention.forward
@torch.no_grad()
def llama_attn_forward(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    block_tables: torch.Tensor = None,
    k_cache: torch.Tensor = None,
    v_cache: torch.Tensor = None,
    is_prompts: bool = True,
    sequence_lengths: torch.Tensor = None,
    kv_seq_len: int = 0,
    cos_sin: Tuple[torch.Tensor] = None,
    fd_inter_tensor: FDIntermTensors = None,
    output_tensor: torch.Tensor = None,
    sm_scale: int = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    query_states = torch.mm(hidden_states, self.q_proj.weight.transpose(0, 1)).view(-1, self.num_heads, self.head_dim)
    key_states = torch.mm(hidden_states, self.k_proj.weight.transpose(0, 1)).view(
        -1, self.num_key_value_heads, self.head_dim
    )
    value_states = torch.mm(hidden_states, self.v_proj.weight.transpose(0, 1)).view(
        -1, self.num_key_value_heads, self.head_dim
    )

    rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])

    _, _, _, block_size = k_cache.shape

    if is_prompts:
        attn_output = context_attention_unpadded(
            q=query_states,
            k=key_states,
            v=value_states,
            k_cache=k_cache,
            v_cache=v_cache,
            context_lengths=sequence_lengths,
            block_tables=block_tables,
            block_size=block_size,
            output=output_tensor,
            max_seq_len=kv_seq_len,
            sm_scale=sm_scale,
        )
    else:
        copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
        copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
        attn_output = flash_decoding_attention(
            q=query_states,
            k_cache=k_cache,
            v_cache=v_cache,
            kv_seq_len=sequence_lengths,
            block_tables=block_tables,
            block_size=block_size,
            max_seq_len_in_batch=kv_seq_len,
            output=output_tensor,
            mid_output=fd_inter_tensor.mid_output,
            mid_output_lse=fd_inter_tensor.mid_output_lse,
            sm_scale=sm_scale,
        )
        attn_output = attn_output.squeeze(1)

    attn_output = attn_output.view(-1, self.num_heads, self.head_dim)
    attn_output = attn_output.reshape(-1, self.hidden_size)
    attn_output = torch.mm(attn_output, self.o_proj.weight.transpose(0, 1))

    return attn_output


@torch.no_grad()
def nopad_mlp(self: LlamaMLP, hidden_states: torch.Tensor):
    gate_proj_out = torch.mm(hidden_states, self.gate_proj.weight.transpose(0, 1))
    act_out = torch.nn.functional.silu(gate_proj_out, inplace=True)
    up_proj_out = torch.mm(hidden_states, self.up_proj.weight.transpose(0, 1))
    tmp_out = act_out * up_proj_out
    return torch.mm(tmp_out, self.down_proj.weight.transpose(0, 1))
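A small self-contained illustration of the last-token gather in `llama_model_forward` above: with all sequences packed along one token dimension, the cumulative sum of the per-sequence lengths points one past each sequence's final token.

```python
# Toy demonstration of selecting each sequence's last hidden state from a
# packed (no-padding) token dimension, as done in llama_model_forward above.
import torch

hidden_dim = 4
sequence_lengths = torch.tensor([3, 5, 2])              # three sequences, 10 tokens total
hidden_states = torch.randn(int(sequence_lengths.sum()), hidden_dim)

last_token_indexs = sequence_lengths.cumsum(dim=-1)     # tensor([3, 8, 10])
last_hidden = hidden_states[last_token_indexs - 1]      # rows 2, 7, 9
assert last_hidden.shape == (3, hidden_dim)
```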
@@ -11,6 +11,7 @@ from colossalai.kernel.triton import (
    context_attention_unpadded,
    copy_kv_to_blocked_cache,
    flash_decoding_attention,
    get_xine_cache,
    rotary_embedding,
)
from colossalai.logging import get_dist_logger
@@ -101,12 +102,7 @@ def llama_model_forward(

    hidden_states = self.embed_tokens(input_ids)

    # When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
    # cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
    # sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
    # cos_sin = (cos, sin)

    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)
    cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts)

    if batch.is_prompts:
        output_tensor = torch.zeros(
@@ -135,7 +131,9 @@ def llama_model_forward(
            sm_scale=sm_scale,
        )

    hidden_states = hidden_states[:, -1, :].unsqueeze(dim=1).contiguous()
    hidden_states = self.norm(hidden_states)

    return hidden_states
@@ -327,26 +325,3 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_
    k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
    v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices)
    return (q, k, v, indices)


@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
    """
    Get cos and sin for the cache, and return them in nopad format.
    Args:
        lengths: shape (num_seqs,), stores the length of each sequence.
        cos_cache: shape (max_rotary_position (e.g. 2048), head_dim), cos cache constructed in the model.
        sin_cache: shape (max_rotary_position (e.g. 2048), head_dim), sin cache constructed in the model.
        is_prompts: bool, marks whether we are in prefill mode.
        dtype: The data type of this inference process.
    """

    if is_prompts:
        index_arrays = [torch.arange(length) for length in lengths]
    else:
        index_arrays = [(length - 1).view(-1) for length in lengths]
    indices = torch.cat(index_arrays, dim=-1)
    cos_output = cos_cache[indices].to(dtype=dtype)
    sin_output = sin_cache[indices].to(dtype=dtype)

    return (cos_output, sin_output)
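To make the indexing in the removed `get_cos_sin` concrete, here is a tiny self-contained run with toy shapes: prefill gathers one cos row per real token, decoding gathers only the current position of each sequence.

```python
# Toy run of the get_cos_sin indexing; shapes are illustrative only.
import torch

head_dim = 4
cos_cache = torch.randn(16, head_dim)     # one row per rotary position
lengths = torch.tensor([3, 2])

prefill_idx = torch.cat([torch.arange(l) for l in lengths])    # tensor([0, 1, 2, 0, 1])
decode_idx = torch.cat([(l - 1).view(-1) for l in lengths])    # tensor([2, 1])

assert cos_cache[prefill_idx].shape == (5, head_dim)
assert cos_cache[decode_idx].shape == (2, head_dim)
```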
@@ -1,7 +1,9 @@
from .llama import LlamaModelInferPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
from .padding_llama import PaddingLlamaModelInferPolicy

model_policy_map = {
    "llama": LlamaModelInferPolicy,
    "padding_llama": PaddingLlamaModelInferPolicy,
    "nopadding_llama": NoPaddingLlamaModelInferPolicy,
}

__all__ = ["LlamaModelInferPolicy", "model_polic_map"]
__all__ = ["PaddingLlamaModelInferPolicy", "NoPaddingLlamaModelInferPolicy", "model_polic_map"]
@@ -0,0 +1,107 @@
from functools import partial

import torch
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaFlashAttention2,
    LlamaForCausalLM,
    LlamaMLP,
    LlamaModel,
    LlamaRMSNorm,
    LlamaSdpaAttention,
)

from colossalai.inference.modeling.models.nopadding_llama import (
    llama_attn_forward,
    llama_causal_lm_forward,
    llama_decoder_layer_forward,
    llama_model_forward,
    nopad_mlp,
)
from colossalai.inference.utils import init_to_get_rotary

# import colossalai
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

try:
    from colossalai.kernel.triton import rms_layernorm

    HAS_TRITON_RMSNORM = True
except:
    print("you should install triton from https://github.com/openai/triton")
    HAS_TRITON_RMSNORM = False


def get_triton_rmsnorm_forward():
    if HAS_TRITON_RMSNORM:

        def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
            return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon)

        return _triton_rmsnorm_forward
    else:
        return None


class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        policy = super().module_policy()
        self.shard_config._infer()

        infer_forward = llama_causal_lm_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaForCausalLM
        )

        infer_forward = llama_model_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)

        infer_forward = llama_decoder_layer_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
        )

        infer_forward = nopad_mlp
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaMLP)

        infer_forward = llama_attn_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaAttention
        )

        infer_forward = llama_attn_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaFlashAttention2
        )

        infer_forward = llama_attn_forward
        method_replacement = {"forward": partial(infer_forward)}
        self.append_or_create_method_replacement(
            description=method_replacement, policy=policy, target_key=LlamaSdpaAttention
        )

        infer_forward = None
        if HAS_TRITON_RMSNORM:
            infer_forward = get_triton_rmsnorm_forward()

        if infer_forward is not None:
            method_replacement = {"forward": partial(infer_forward)}
            self.append_or_create_method_replacement(
                description=method_replacement, policy=policy, target_key=LlamaRMSNorm
            )

        return policy

    def postprocess(self):
        init_to_get_rotary(self.model.model)
        return self.model
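The policy above swaps each module's `forward` for a standalone no-padding function via `partial`. Outside of Shardformer, the general pattern can be sketched with plain PyTorch by binding a replacement function onto a module instance; this is only an illustration of the idea, not how ColossalAI applies the replacement internally.

```python
# Minimal illustration of the forward-replacement pattern; ColossalAI itself
# routes this through Shardformer's method-replacement machinery.
import types

import torch
import torch.nn as nn

def inference_forward(self: nn.Linear, x: torch.Tensor) -> torch.Tensor:
    # Stand-in "inference" forward: same math as nn.Linear without bias.
    return torch.mm(x, self.weight.transpose(0, 1))

layer = nn.Linear(4, 4, bias=False)
layer.forward = types.MethodType(inference_forward, layer)  # replace forward on this instance

out = layer(torch.randn(2, 4))  # now dispatches to inference_forward
```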
@@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import (
    LlamaSdpaAttention,
)

from colossalai.inference.modeling.models.llama import (
from colossalai.inference.modeling.models.padding_llama import (
    llama_attn_forward,
    llama_causal_lm_forward,
    llama_decoder_layer_forward,
@@ -43,7 +43,7 @@ def get_triton_rmsnorm_forward():
        return None


class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
class PaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
    def __init__(self) -> None:
        super().__init__()
@@ -358,21 +358,16 @@ class BatchInfo:
        Flattening the input tokens.
        """
        input_list = []
        input_len_list = []

        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

        for seq in self.sequences_set:
            if self.is_prompts:
                input_list.extend(seq.input_token_id)
                input_len_list.append(seq.sentence_len)
            else:
                input_list.append(seq.output_token_id[-1])
                input_len_list.append(1)

        return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
            input_len_list, dtype=torch.int, device=self.device
        )
        return torch.tensor(input_list, dtype=torch.long, device=self.device)

    def get_sequence_lengths(self):
        """
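A compact view of what the updated `get_1D_inputs` returns in the two modes: prefill flattens every prompt token into one 1-D tensor, while decoding takes only the newest token of each sequence. The sequence objects below are simplified stand-ins, not the real `Sequence` struct.

```python
# Toy stand-ins for the batch's sequences; only the fields used here are modeled.
from dataclasses import dataclass, field

import torch

@dataclass
class ToySequence:
    input_token_id: list
    output_token_id: list = field(default_factory=list)

sequences = [ToySequence([1, 2, 3], [7]), ToySequence([4, 5], [8])]

prefill = torch.tensor([t for s in sequences for t in s.input_token_id], dtype=torch.long)
decode = torch.tensor([s.output_token_id[-1] for s in sequences], dtype=torch.long)

print(prefill.tolist())  # [1, 2, 3, 4, 5]
print(decode.tolist())   # [7, 8]
```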
@@ -401,7 +396,9 @@ class BatchInfo:
            past_values.append(seq.input_token_id + seq.output_token_id)

        max_seq_len = max(len(sub_list) for sub_list in past_values)
        attn_mask = _make_tensor_with_pad(past_values, max_seq_len, 0, dtype=torch.int, device=self.device)
        attn_mask = _make_tensor_with_pad(
            past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
        )

        return attn_mask.ne(padding_id).long()
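The change above makes the attention mask respect the sequences' own pad token instead of a hard-coded 0. The step itself boils down to right-padding the token lists and marking non-pad positions; a pure-PyTorch sketch (the pad id is an arbitrary example value):

```python
# Sketch of building an attention mask from padded token ids; pad_token_id is
# an arbitrary example, not taken from any particular tokenizer.
import torch

pad_token_id = 0
past_values = [[5, 6, 7, 8], [9, 10]]
max_seq_len = max(len(v) for v in past_values)

padded = torch.tensor(
    [v + [pad_token_id] * (max_seq_len - len(v)) for v in past_values], dtype=torch.int
)
attn_mask = padded.ne(pad_token_id).long()  # 1 for real tokens, 0 for padding
print(attn_mask.tolist())                   # [[1, 1, 1, 1], [1, 1, 0, 0]]
```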
@@ -2,7 +2,6 @@ import pytest
import torch
from packaging import version

from colossalai.inference.modeling.models.llama import get_cos_sin
from colossalai.kernel.triton import get_xine_cache

try:
@@ -16,6 +15,29 @@ except ImportError:
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


@torch.no_grad()
def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
    """
    Get cos and sin for the cache, and return them in nopad format.
    Args:
        lengths: shape (num_seqs,), stores the length of each sequence.
        cos_cache: shape (max_rotary_position (e.g. 2048), head_dim), cos cache constructed in the model.
        sin_cache: shape (max_rotary_position (e.g. 2048), head_dim), sin cache constructed in the model.
        is_prompts: bool, marks whether we are in prefill mode.
        dtype: The data type of this inference process.
    """

    if is_prompts:
        index_arrays = [torch.arange(length) for length in lengths]
    else:
        index_arrays = [(length - 1).view(-1) for length in lengths]
    indices = torch.cat(index_arrays, dim=-1)
    cos_output = cos_cache[indices].to(dtype=dtype)
    sin_output = sin_cache[indices].to(dtype=dtype)

    return (cos_output, sin_output)


@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("MAX_SEQ_LEN", [64])
@pytest.mark.parametrize("HEAD_DIM", [64])
@@ -23,15 +45,18 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
def test_get_xine_cache(BATCH_SIZE, MAX_SEQ_LEN, HEAD_DIM, dtype):
    MAX_TOTAL_TOKENS = BATCH_SIZE * MAX_SEQ_LEN
    cos_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
    sin_cache = torch.randn((MAX_TOTAL_TOKENS, HEAD_DIM), dtype=dtype, device="cuda")
    lengths = torch.randint(2, MAX_SEQ_LEN, (BATCH_SIZE,), device="cuda")
    # prefill
    cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, cos_cache, is_prompts=True, dtype=dtype)
    cos = get_xine_cache(lengths, cos_cache, is_prompts=True)
    cos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
    cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
    assert torch.allclose(cos, cos_ref)
    assert torch.allclose(sin, sin_ref)
    # decoding
    ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, cos_cache, is_prompts=False, dtype=dtype)
    cos = get_xine_cache(lengths, cos_cache, is_prompts=False)
    ncos_ref, sin_ref = get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=False, dtype=dtype)
    cos, sin = get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=False)
    assert torch.allclose(cos, ncos_ref)
    assert torch.allclose(sin, sin_ref)


configs = [