[Inference/SpecDec] Support GLIDE Drafter Model (#5455)

* add glide-llama policy and modeling

* update glide modeling, compatible with transformers 4.36.2

* revise glide llama modeling/usage

* fix issues of glimpsing large kv

* revise the way re-loading params for glide drafter

* fix drafter and engine tests

* enable convert to glide strict=False

* revise glide llama modeling

* revise vicuna prompt template

* revise drafter and tests

* apply usage of glide model in engine
feat/speculative-decoding
Yuanheng Zhao 2024-04-01 21:54:24 +08:00 committed by Yuanheng
parent 912e24b2aa
commit d85d91435a
10 changed files with 722 additions and 82 deletions

View File

@ -26,7 +26,7 @@ _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
_DEFAULT_PROMPT_TEMPLATES = {
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
}

View File

@ -12,7 +12,7 @@ from colossalai.inference.batch_bucket import BatchBucket
from colossalai.inference.config import InferenceConfig, InputMetaData
from colossalai.inference.graph_runner import CUDAGraphRunner
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.spec import Drafter
from colossalai.inference.spec import Drafter, GlideInput
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
@ -72,6 +72,7 @@ class InferenceEngine:
self.use_spec_dec = False
self.drafter_model = None
self.drafter = None
self.use_glide = False
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
if model_policy is None:
@ -229,7 +230,12 @@ class InferenceEngine:
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model
def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
def enable_spec_dec(
self,
drafter_model: nn.Module = None,
n_spec_tokens: int = None,
use_glide_drafter: bool = False,
) -> None:
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
Args:
@ -237,6 +243,8 @@ class InferenceEngine:
If provided, the previous drafter and drafter model, if exist, will be overwritten.
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
If True, the drafter model will be replaced by a glide model.
```python
...
@ -269,6 +277,22 @@ class InferenceEngine:
device=self.device,
dtype=self.dtype,
)
# check if the provided drafter model is compatible with GLIDE structure
# when `use_glide_drafter` is set to True
if (
use_glide_drafter
and hasattr(drafter_model, "model")
and hasattr(drafter_model.model, "layers")
and hasattr(drafter_model.model.layers[0], "cross_attn")
):
self.use_glide = use_glide_drafter
elif use_glide_drafter:
self.logger.warning(
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
f"but the provided drafter model is not compatible with GLIDE structure."
f"Falling back to use the default drafter model (non-GLIDE)."
)
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
# using speculative decoding for subsequent generations
self.use_spec_dec = True
@ -278,6 +302,7 @@ class InferenceEngine:
self.request_handler.unset_spec_dec_mode()
# set back to the maximum number of tokens to speculate
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
self.use_glide = False
self.use_spec_dec = False
def clear_spec_dec(self) -> None:
@ -288,6 +313,7 @@ class InferenceEngine:
self.drafter_model = None
self.drafter = None
torch.cuda.empty_cache()
self.use_glide = False
self.use_spec_dec = False
def steps_spec_dec(self) -> List[Sequence]:
@ -304,6 +330,7 @@ class InferenceEngine:
input_ids = batch.get_1D_inputs() # bsz 1 for drafter model
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
drafter_out = self.drafter.speculate(input_ids, 1, None)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
@ -326,7 +353,21 @@ class InferenceEngine:
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
# 3. Decoding - Drafter model speculates `n` tokens
drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
glide_input = None
if self.use_glide:
glide_input = GlideInput(
batch.get_block_table_tensor(),
self.k_cahce[-1], # use kv cahces of the last layer
self.v_cache[-1],
batch.get_sequence_lengths(),
)
drafter_out = self.drafter.speculate(
input_ids,
self.n_spec_tokens,
drafter_past_key_values,
glide_input=glide_input,
)
next_token_ids_spec = drafter_out.next_tokens
drafter_past_key_values = drafter_out.past_key_values
drafter_spec_length = drafter_out.speculated_length
@ -339,6 +380,8 @@ class InferenceEngine:
already_allocated_kv_len = cur_length
# 4. Decoding - Main model verifies `n` tokens in parallel
if drafter_spec_length < batch.num_tokens_to_verify:
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
logits = self.model(batch, self.k_cahce, self.v_cache)
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
@ -348,6 +391,7 @@ class InferenceEngine:
# revoke appended tokens for each Sequence in the current batch
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
# append the last correct token generated by the main model
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
@ -355,6 +399,7 @@ class InferenceEngine:
drafter_past_key_values = Drafter.trim_kv_cache(
drafter_past_key_values, drafter_spec_length - n_matches - 1
)
# prepare inputs for the next round of speculation
n = 1 if n_matches < drafter_spec_length else 2
input_ids = batch.get_1D_inputs_spec_dec(n)
@ -364,6 +409,11 @@ class InferenceEngine:
if len(finished_sequences) > 0:
break
# Reset back the number of speculated tokens of the batch,
# this is used to handle the last round of speculation, in which case the number of speculated tokens
# by the drafter is less than the number of speculated tokens set to the engine.
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
return finished_sequences
def generate(

View File

@ -0,0 +1,475 @@
# This is modified from huggingface transformers
# https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/llama/modeling_llama.py
import warnings
from types import MethodType
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
LlamaDecoderLayer,
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaForCausalLM,
LlamaLinearScalingRotaryEmbedding,
LlamaMLP,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
)
from colossalai.inference.spec import GlideInput
from colossalai.kernel.triton import flash_decoding_attention
from colossalai.logging import get_dist_logger
logger = get_dist_logger(__name__)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last dimension is split at ``x.shape[-1] // 2``; the result is the
    concatenation ``(-second_part, first_part)`` along that dimension.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
def apply_single_rotary_pos_emb(q, cos, sin, position_ids):
    """Apply rotary position embedding to a single tensor (queries only).

    Args:
        q: Tensor to rotate.
        cos: Cosine table produced by the rotary embedding module.
        sin: Sine table produced by the rotary embedding module.
        position_ids: Indices used to pick rows out of the cos/sin tables.

    Returns:
        ``q * cos + rotate_half(q) * sin`` with cos/sin gathered at ``position_ids``.
    """
    # Remove size-1 axes at dim 1 and dim 0 when present, leaving a table
    # indexable by position (expected [seq_len, dim] -- TODO confirm for the
    # installed transformers version).
    cos_table = cos.squeeze(1).squeeze(0)
    sin_table = sin.squeeze(1).squeeze(0)
    # Gather per-position rows, then insert an axis at dim 1 so the result
    # broadcasts against q (presumably the head axis -- confirm).
    cos_at_pos = cos_table[position_ids].unsqueeze(1)
    sin_at_pos = sin_table[position_ids].unsqueeze(1)
    rotated_part = rotate_half(q) * sin_at_pos
    return (q * cos_at_pos) + rotated_part
def glide_llama_causal_lm_forward(
    self: LlamaForCausalLM,
    input_ids: torch.LongTensor = None,
    glide_input: Optional[GlideInput] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Forward pass of the causal-LM head for a GLIDE llama drafter.

    Runs ``self.model`` with the extra ``glide_input`` argument forwarded to it,
    projects the resulting hidden states with ``self.lm_head`` and casts the
    logits to float32.

    Args:
        input_ids: Token ids fed to the underlying model.
        glide_input: Packed data used for glimpsing the main model's KV caches;
            passed through unchanged to ``self.model``.
        attention_mask: Passed through to ``self.model``.
        position_ids: Passed through to ``self.model``.
        past_key_values: Passed through to ``self.model``.
        inputs_embeds: Passed through to ``self.model``.
        labels: Accepted for signature compatibility only. It is not read in
            this function and no loss is computed.
        use_cache: Passed through to ``self.model``.
        output_attentions: Defaults to ``self.config.output_attentions`` when None.
        output_hidden_states: Defaults to ``self.config.output_hidden_states`` when None.
        return_dict: Defaults to ``self.config.use_return_dict`` when None.

    Returns:
        When ``return_dict`` is falsy, a tuple ``(logits, *outputs[1:])``.
        Otherwise a ``CausalLMOutputWithPast`` whose ``loss`` is always ``None``.
    """
    # Resolve unset flags from the model config.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        glide_input=glide_input,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # Index 0 is the final hidden states for both tuple and dataclass outputs.
    hidden_states = outputs[0]
    logits = self.lm_head(hidden_states)
    logits = logits.float()

    if not return_dict:
        output = (logits,) + outputs[1:]
        return output

    # No loss is computed here, regardless of `labels`.
    return CausalLMOutputWithPast(
        loss=None,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
def glide_llama_model_forward(
    self: LlamaModel,
    input_ids: torch.LongTensor = None,
    glide_input: GlideInput = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Forward pass of the decoder stack for a GLIDE llama drafter.

    Same flow as a plain llama decoder stack (embed, build attention mask, run
    every layer, final norm), except that ``glide_input`` is handed to every
    decoder layer.

    Args:
        input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
        glide_input: Packed data for glimpsing the main model's KV caches;
            forwarded unchanged to each decoder layer.
        attention_mask: Optional mask; converted below according to the
            attention implementation flags on ``self``.
        position_ids: Optional positions; when None they are generated as a
            range starting at the current cache length.
        past_key_values: Cached key/values, either a ``Cache`` object or the
            legacy tuple format (converted to ``DynamicCache`` when caching).
        inputs_embeds: Pre-computed embeddings; mutually exclusive with ``input_ids``.
        use_cache: Defaults to ``self.config.use_cache`` when None.
        output_attentions: Defaults to ``self.config.output_attentions`` when None.
        output_hidden_states: Defaults to ``self.config.output_hidden_states`` when None.
        return_dict: Defaults to ``self.config.use_return_dict`` when None.

    Returns:
        A ``BaseModelOutputWithPast``, or a tuple of its non-None fields when
        ``return_dict`` is falsy.

    Raises:
        ValueError: If both or neither of ``input_ids`` / ``inputs_embeds`` are given.
    """
    # Resolve unset flags from the model config.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    past_key_values_length = 0
    if use_cache:
        # `use_legacy_cache` is only defined on this path; it is read again
        # at the end under the same `use_cache` condition.
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)

    if position_ids is None:
        # Positions continue from the end of the cached sequence.
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if self._use_flash_attention_2:
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    elif self._use_sdpa and not output_attentions:
        # output_attentions=True can not be supported when using SDPA, and we fall back on
        # the manual implementation that requires a 4D causal mask in all cases.
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # GlideLlamaDecoderLayer
        # The whole cache object is passed to every layer (not a per-layer slice).
        layer_outputs = decoder_layer(
            hidden_states,
            glide_input=glide_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        hidden_states = layer_outputs[0]

        # Layer output layout assumed by the indexing below:
        # (hidden_states, [attn_weights if output_attentions], [cache if use_cache]).
        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        # Hand the cache back in the same format the caller supplied.
        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
class GlideLlamaConfig(LlamaConfig):
    """Llama configuration extended with the dimensions of the large (verifier) model.

    Args:
        large_hidden_size: Hidden size of the large model whose KV cache the
            drafter attends to. Defaults to 4096.
        large_num_attention_heads: Number of attention heads of the large
            model. Defaults to 32.
        **kwargs: Forwarded unchanged to ``LlamaConfig``.
    """

    def __init__(self, large_hidden_size=4096, large_num_attention_heads=32, **kwargs):
        super().__init__(**kwargs)
        # Stored after the base initialisation, as plain config attributes.
        self.large_num_attention_heads = large_num_attention_heads
        self.large_hidden_size = large_hidden_size
class LlamaCrossAttention(nn.Module):
    """Cross attention from the drafter's hidden states to the large model's KV cache.

    Only a query projection and an output projection are owned by this module:
    queries are projected from the drafter hidden size to the large model's
    ``large_num_heads * large_head_dim``, attention runs against the blocked
    key/value caches supplied through ``GlideInput``, and the result is
    projected back to the drafter hidden size.
    """

    def __init__(self, config: GlideLlamaConfig):
        super().__init__()
        self.config = config
        # Drafter (small model) dimensions.
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # large model (verifier) configs
        self.large_hidden_size = config.large_hidden_size
        self.large_num_heads = config.large_num_attention_heads
        self.large_head_dim = self.large_hidden_size // self.large_num_heads

        # Queries are projected up to the large model's head layout; outputs are projected back down.
        self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
        self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` sized for the large model's head dim.

        Raises:
            ValueError: If ``config.rope_scaling["type"]`` is neither "linear" nor "dynamic".
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.large_head_dim,
                max_position_embeddings=self.max_position_embeddings,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.large_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.large_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    # NOTE(review): reshapes with the drafter's num_heads/head_dim and is not
    # called anywhere inside this class -- confirm whether it is still needed.
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        glide_input: GlideInput = None,  # Used for glimpsing main model's KV caches
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Optional[torch.Tensor]:
        """Attend from ``hidden_states`` to the large model's cached keys/values.

        Args:
            hidden_states: Drafter hidden states; unpacked as (bsz, q_len, hidden).
            glide_input: Block tables, large K/V caches and sequence lengths.
                Despite the ``None`` default it is dereferenced unconditionally,
                so callers must pass a populated instance.
            attention_mask: Not read in this method.
            position_ids: Positions used to index the rotary cos/sin tables.
            output_attentions: Not read in this method.
            use_cache: Not read in this method.

        Returns:
            The attention output after ``o_proj``.
        """
        bsz, q_len, _ = hidden_states.size()

        block_tables = glide_input.block_tables
        large_k_cache = glide_input.large_k_cache
        large_v_cache = glide_input.large_v_cache
        sequence_lengths = glide_input.sequence_lengths
        # The second-to-last axis of the blocked cache is taken as the block size.
        cache_block_size = large_k_cache.size(-2)

        query_states = self.q_proj(hidden_states)
        # `.item()` pulls the max length to the host as a Python number.
        kv_seq_len = sequence_lengths.max().item()

        query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)

        # for RoPE
        # NOTE(review): the `+ 32` margin is unexplained here; presumably headroom so
        # position_ids past kv_seq_len stay inside the cos/sin tables -- confirm.
        cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
        query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
        query_states = query_states.transpose(1, 2)
        # Flatten batch and query-length axes for the decoding kernel.
        query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)

        attn_output = flash_decoding_attention(
            q=query_states,
            k_cache=large_k_cache,
            v_cache=large_v_cache,
            kv_seq_len=sequence_lengths,
            block_tables=block_tables,
            block_size=cache_block_size,
            max_seq_len_in_batch=kv_seq_len,
        )  # attn_output: [bsz * q_len, num_heads * head_dim]

        attn_output = attn_output.reshape(bsz, q_len, self.large_hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output
# A class to be used to replace LlamaDecoderLayer in a Llama Model as Drafter in speculative decoding.
# Refer to GLIDE with a CAPE https://arxiv.org/pdf/2402.02082.pdf
class GlideLlamaDecoderLayer(nn.Module):
    """Llama decoder layer with an extra cross-attention into the main model's KV cache.

    The layer runs self attention, then (only for single-token queries with a
    complete ``GlideInput``) cross attention, then the MLP, each wrapped in a
    residual connection.
    """

    def __init__(self, config: GlideLlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
        self.cross_attn = LlamaCrossAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @staticmethod
    def from_native_module(module: LlamaDecoderLayer, *args, **kwargs) -> "GlideLlamaDecoderLayer":
        """Build a GlideLlamaDecoderLayer from a native LlamaDecoderLayer.

        Only the configuration and layer index are taken from ``module``; the
        returned layer is freshly constructed and no parameters are copied here.
        """
        config: LlamaConfig = module.mlp.config  # XXX
        layer_idx = module.self_attn.layer_idx
        glide_config = GlideLlamaConfig(**config.to_dict())
        glide_decoder_layer = GlideLlamaDecoderLayer(glide_config, layer_idx=layer_idx)

        return glide_decoder_layer

    def forward(
        self,
        hidden_states: torch.Tensor,
        glide_input: GlideInput = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            glide_input (`GlideInput`, *optional*):
                packed block tables / KV caches / sequence lengths of the main model. When it is missing or
                incomplete, cross attention is skipped and a warning is emitted.
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*): positions forwarded to self and cross attention.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states

        Returns:
            A tuple `(hidden_states, [self_attn_weights], [present_key_value])`, where the self-attention weights
            are included only when `output_attentions` is truthy and the cache only when `use_cache` is truthy.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        curr_q_len = hidden_states.size(1)
        # Cross attention
        if glide_input is None or not glide_input.glimpse_ready:
            warnings.warn(
                "Data used for glimpsing the past KV caches of the main model (verifier) is not complete. "
                "Fall back to normal decoder layer modeling (drafter). "
                "This might lead to incorrect results when using the Glide Models for speculative decoding."
            )
        elif curr_q_len == 1:
            # Notice that we skip prefill stage
            # always use the output of the main model as the inputs for the next round of speculation
            residual = hidden_states
            hidden_states = self.cross_attn(
                hidden_states=hidden_states,
                glide_input=glide_input,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                use_cache=True,
            )
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        # The model forward reads attentions at index 1 and the cache at index
        # `2 if output_attentions else 1`, so the weights must be emitted here
        # to keep that layout valid.
        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
class GlideLlamaForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM patched in place to act as a GLIDE drafter.

    After the regular construction, the causal-LM and inner-model ``forward``
    methods are rebound to the GLIDE variants and every decoder layer is
    replaced by a newly built ``GlideLlamaDecoderLayer``.
    """

    def __init__(self, config: GlideLlamaConfig):
        super().__init__(config)
        self.config = config

        # Rebind the top-level forward on this instance.
        self.forward = MethodType(glide_llama_causal_lm_forward, self)

        # Rebind the inner model's forward on that instance.
        inner_forward = MethodType(glide_llama_model_forward, self.model)
        inner_model = self.model
        inner_model.forward = inner_forward

        # Swap in GLIDE decoder layers, one per configured hidden layer.
        glide_layers = [GlideLlamaDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        inner_model.layers = nn.ModuleList(glide_layers)

View File

@ -1,7 +1,9 @@
from .glide_llama import GlideLlamaModelPolicy
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
model_policy_map = {
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
"glide_llama": GlideLlamaModelPolicy,
}
__all__ = ["NoPaddingLlamaModelInferPolicy", "model_polic_map"]
# Names exported by a star-import; every entry must match a name defined in this module.
__all__ = ["NoPaddingLlamaModelInferPolicy", "GlideLlamaModelPolicy", "model_policy_map"]

View File

@ -0,0 +1,45 @@
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from colossalai.inference.modeling.models.glide_llama import (
GlideLlamaDecoderLayer,
glide_llama_causal_lm_forward,
glide_llama_model_forward,
)
from colossalai.inference.utils import init_to_get_rotary
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
class GlideLlamaModelPolicy(LlamaForCausalLMPolicy):
    """Shardformer policy that turns a llama causal LM into a GLIDE drafter."""

    def module_policy(self):
        """Extend the parent policy with GLIDE layer and forward replacements.

        Every ``layers[i]`` of ``LlamaModel`` is targeted for replacement by
        ``GlideLlamaDecoderLayer``, and the ``forward`` methods of ``LlamaModel``
        and ``LlamaForCausalLM`` are replaced by their GLIDE counterparts.
        """
        policy = super().module_policy()

        layer_count = self.model.config.num_hidden_layers
        layer_swaps = [
            SubModuleReplacementDescription(
                suffix=f"layers[{layer_id}]",
                target_module=GlideLlamaDecoderLayer,
            )
            for layer_id in range(layer_count)
        ]
        self.append_or_create_submodule_replacement(
            description=layer_swaps,
            policy=policy,
            target_key=LlamaModel,
        )

        # Same registration order as before: inner model first, then the LM wrapper.
        forward_overrides = (
            (LlamaModel, glide_llama_model_forward),
            (LlamaForCausalLM, glide_llama_causal_lm_forward),
        )
        for target_cls, forward_fn in forward_overrides:
            self.append_or_create_method_replacement(
                description={"forward": forward_fn},
                policy=policy,
                target_key=target_cls,
            )

        return policy

    def postprocess(self):
        """Run ``init_to_get_rotary`` on each layer's cross attention and return the model."""
        decoder_layers = self.model.model.layers
        for decoder_layer in decoder_layers:
            init_to_get_rotary(decoder_layer.cross_attn)
        return self.model

View File

@ -1,4 +1,4 @@
from .drafter import Drafter
from .struct import DrafterOutput
from .struct import DrafterOutput, GlideInput
__all__ = ["Drafter", "DrafterOutput"]
__all__ = ["Drafter", "DrafterOutput", "GlideInput"]

View File

@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer
from colossalai.utils import get_current_device
from .struct import DrafterOutput
from .struct import DrafterOutput, GlideInput
class Drafter:
@ -66,6 +66,7 @@ class Drafter:
input_ids: torch.Tensor,
n_spec_tokens: int,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
glide_input: Optional[GlideInput] = None,
) -> DrafterOutput:
"""Generate n_spec_tokens tokens using the drafter model.
@ -73,6 +74,8 @@ class Drafter:
input_ids (torch.Tensor): Input token ids.
n_spec_tokens (int): Number of tokens to speculate.
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
glide_input (Optional[GlideInput]): The packed input for glimpsing kv caches of the main model,
when using the glide model as a drafter.
"""
assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"
@ -83,13 +86,16 @@ class Drafter:
logits = []
token_ids = []
kwargs = {"return_dict": True, "use_cache": True}
if glide_input:
# required only when using glide model
kwargs["glide_input"] = glide_input
for _ in range(n_spec_tokens):
outputs = self._drafter_model(
input_ids,
return_dict=True,
use_cache=True,
past_key_values=past_key_values,
)
# update past key values
kwargs["past_key_values"] = past_key_values
outputs = self._drafter_model(input_ids, **kwargs)
next_token_logits = outputs.logits[:, -1, :]
# NOTE Only use greedy search for speculating.
@ -100,12 +106,12 @@ class Drafter:
logits.append(next_token_logits)
token_ids.append(next_token_ids)
if next_token_ids.item() == self._tokenizer.eos_token_id:
# TODO support bsz > 1
# TODO(yuanheng-zhao) support bsz > 1
break
input_ids = next_token_ids[:, None]
past_key_values = outputs.past_key_values
speculated_length = len(token_ids) # TODO For now, only support bsz 1
speculated_length = len(token_ids) # For now, only support bsz 1
logits = torch.concat(logits, dim=0)
token_ids = torch.concat(token_ids, dim=-1)

View File

@ -27,3 +27,29 @@ class DrafterOutput:
if self.past_key_values is not None:
assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
assert all([isinstance(past_key_value, tuple) for past_key_value in self.past_key_values])
@dataclass
class GlideInput:
    """Bundle of tensors a Glide model needs to glimpse the main model's KV caches.

    See e.g. `colossalai/inference/modeling/models/glide_llama.py` for a consumer.

    Args:
        block_tables (torch.Tensor): [num_seqs, max_blocks_per_seq] The block table of KV Caches.
        large_k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_size]
            Blocked key cache of the main model
        large_v_cache (torch.Tensor): Blocked value cache of the main model. It has the same shape as k cache.
        sequence_lengths (torch.Tensor): [num_seqs] Sequence lengths of the current batch.
    """

    block_tables: torch.Tensor = None
    large_k_cache: torch.Tensor = None
    large_v_cache: torch.Tensor = None
    sequence_lengths: torch.Tensor = None

    @property
    def glimpse_ready(self):
        """True only when none of the four fields is None."""
        required = (
            self.block_tables,
            self.large_k_cache,
            self.large_v_cache,
            self.sequence_lengths,
        )
        return not any(item is None for item in required)

View File

@ -2,18 +2,16 @@ import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
import colossalai
from colossalai.inference.config import GenerationConfig, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.inference.spec.drafter import Drafter
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
NUM_LAYERS = 2
NUM_LAYERS = 1
MAX_LEN = 100
SPEC_NUM = 5
@pytest.mark.parametrize("spec_num", [5])
@pytest.mark.parametrize("spec_num", [SPEC_NUM])
def test_drafter(spec_num: int):
torch.manual_seed(123)
@ -41,68 +39,33 @@ def test_drafter(spec_num: int):
assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
def check_sd():
torch.manual_seed(123)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
# Dummy configs for testing
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
toy_config.pad_token_id = tokenizer.eos_token_id
drafter_model = LlamaForCausalLM(toy_config)
drafter_model = drafter_model.eval().cuda()
large_config = LlamaConfig(
hidden_size=4096,
intermediate_size=11008,
num_attention_heads=32,
num_hidden_layers=8,
num_key_value_heads=32,
max_position_embeddings=2048,
)
large_config.pad_token_id = tokenizer.eos_token_id
main_model = LlamaForCausalLM(large_config)
inference_config = InferenceConfig(
dtype="fp16",
micro_batch_size=1,
max_batch_size=1,
max_input_len=128,
max_output_len=128,
prefill_ratio=1.2,
block_size=16,
)
engine = InferenceEngine(main_model, tokenizer, inference_config)
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
generation_config = GenerationConfig(
pad_token_id=tokenizer.eos_token_id,
max_length=MAX_LEN,
eos_token_id=tokenizer.eos_token_id,
)
out, out_token_ids = engine.generate(
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
)
engine.disable_spec_dec()
engine.clear_spec_dec()
assert not engine.use_spec_dec
assert engine.drafter is None and engine.drafter_model is None
assert len(out) == 1
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == MAX_LEN
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
check_sd()
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_spec_dec():
spawn(run_dist, nprocs=1)
spec_num = SPEC_NUM
device = get_current_device()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token = tokenizer.eos_token
# Dummy config for Glide Model
glide_config = GlideLlamaConfig(
intermediate_size=8192,
large_hidden_size=4096,
large_num_attention_heads=32,
num_hidden_layers=NUM_LAYERS,
)
drafter_model = GlideLlamaForCausalLM(glide_config)
assert hasattr(drafter_model, "model")
assert hasattr(drafter_model.model, "layers")
for _, layer in enumerate(drafter_model.model.layers):
assert hasattr(layer, "cross_attn")
# Init the Drafter by providing the sharded drafter model
drafter = Drafter(drafter_model, tokenizer, device=device, dtype=torch.float16)
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
out = drafter.speculate(input_ids, spec_num, past_key_values=None)
if __name__ == "__main__":
test_drafter(spec_num=5)
test_drafter(spec_num=SPEC_NUM)
test_spec_dec()

View File

@ -9,6 +9,7 @@ import colossalai
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -80,9 +81,81 @@ def check_output_consistency(prompt_template):
FDIntermTensors._instances = {}
@parameterize("num_layers", [1])
@parameterize("max_length", [100])
def check_spec_dec(num_layers, max_length):
    """Run speculative decoding end to end, first with a plain llama drafter, then with a GLIDE drafter.

    Both runs use the same engine and the same random prompt, and each asserts
    that exactly one sequence of ``max_length`` tokens is produced.
    """
    # Fixed seed: the randomly initialised models and the dummy prompt below
    # depend on the order in which they are created.
    torch.manual_seed(123)

    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    # Dummy configs for testing
    toy_config = LlamaConfig(num_hidden_layers=num_layers)
    toy_config.pad_token_id = tokenizer.eos_token_id
    drafter_model = LlamaForCausalLM(toy_config)
    drafter_model = drafter_model.eval().cuda()
    large_config = LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=8,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    )
    large_config.pad_token_id = tokenizer.eos_token_id
    main_model = LlamaForCausalLM(large_config)

    inference_config = InferenceConfig(
        dtype="fp16",
        micro_batch_size=1,
        max_batch_size=1,
        max_input_len=128,
        max_output_len=128,
        prefill_ratio=1.2,
        block_size=16,
    )
    engine = InferenceEngine(main_model, tokenizer, inference_config)
    # First pass: regular (non-GLIDE) drafter.
    engine.enable_spec_dec(drafter_model, n_spec_tokens=5)

    dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
    generation_config = GenerationConfig(
        pad_token_id=tokenizer.eos_token_id,
        max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
    )
    out, out_token_ids = engine.generate(
        prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
    )
    # Turning speculative decoding off and clearing it must drop the drafter entirely.
    engine.disable_spec_dec()
    engine.clear_spec_dec()

    assert not engine.use_spec_dec
    assert engine.drafter is None and engine.drafter_model is None

    assert len(out) == 1
    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length

    # test GLIDE model
    # large_hidden_size / large_num_attention_heads mirror `large_config` above.
    glide_config = GlideLlamaConfig(
        intermediate_size=8192,
        large_hidden_size=4096,
        large_num_attention_heads=32,
        num_hidden_layers=num_layers,
    )
    glide_model = GlideLlamaForCausalLM(glide_config)
    # n_spec_tokens is left unset here, so the engine's default is used.
    engine.enable_spec_dec(glide_model, use_glide_drafter=True)

    out, out_token_ids = engine.generate(
        prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
    )
    engine.clear_spec_dec()

    assert len(out) == 1
    assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
def run_dist(rank, world_size, port):
    """Per-process entry point: launch colossalai on localhost, then run both checks.

    Args:
        rank: Rank passed to ``colossalai.launch``.
        world_size: World size passed to ``colossalai.launch``.
        port: Port passed to ``colossalai.launch``.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
    check_output_consistency()
    check_spec_dec()
@pytest.mark.dist