mirror of https://github.com/hpcaitech/ColossalAI
[Inference/SpecDec] Support GLIDE Drafter Model (#5455)
* add glide-llama policy and modeling * update glide modeling, compitable with transformers 4.36.2 * revise glide llama modeling/usage * fix issues of glimpsing large kv * revise the way re-loading params for glide drafter * fix drafter and engine tests * enable convert to glide strict=False * revise glide llama modeling * revise vicuna prompt template * revise drafter and tests * apply usage of glide model in enginefeat/speculative-decoding
parent
912e24b2aa
commit
d85d91435a
|
@ -26,7 +26,7 @@ _ALLOWED_DTYPES = [torch.float16, torch.bfloat16, torch.float32]
|
|||
|
||||
_DEFAULT_PROMPT_TEMPLATES = {
|
||||
"llama": "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n{input_text}[/INST]",
|
||||
"vicuna": "USER: {input_text}\n\nASSISTANT: ",
|
||||
"vicuna": "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user input. USER: {input_text}\nASSISTANT: ",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from colossalai.inference.batch_bucket import BatchBucket
|
|||
from colossalai.inference.config import InferenceConfig, InputMetaData
|
||||
from colossalai.inference.graph_runner import CUDAGraphRunner
|
||||
from colossalai.inference.modeling.policy import model_policy_map
|
||||
from colossalai.inference.spec import Drafter
|
||||
from colossalai.inference.spec import Drafter, GlideInput
|
||||
from colossalai.inference.struct import Sequence
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
@ -72,6 +72,7 @@ class InferenceEngine:
|
|||
self.use_spec_dec = False
|
||||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
self.use_glide = False
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
|
||||
if model_policy is None:
|
||||
|
@ -229,7 +230,12 @@ class InferenceEngine:
|
|||
shard_model, _ = shardformer.optimize(model, model_policy)
|
||||
return shard_model
|
||||
|
||||
def enable_spec_dec(self, drafter_model: nn.Module = None, n_spec_tokens: int = None) -> None:
|
||||
def enable_spec_dec(
|
||||
self,
|
||||
drafter_model: nn.Module = None,
|
||||
n_spec_tokens: int = None,
|
||||
use_glide_drafter: bool = False,
|
||||
) -> None:
|
||||
"""Initialize drafter (if it has not yet), and enable Speculative Decoding for subsequent generations.
|
||||
|
||||
Args:
|
||||
|
@ -237,6 +243,8 @@ class InferenceEngine:
|
|||
If provided, the previous drafter and drafter model, if exist, will be overwritten.
|
||||
n_spec_tokens (Optional[int]): The number of tokens to speculate in each round of speculating-verifying.
|
||||
If not provided, `max_n_spec_tokens` in InferenceConfig will be used.
|
||||
use_glide_drafter (bool): Whether to use glide model for speculative decoding. Defaults to False.
|
||||
If True, the drafter model will be replaced by a glide model.
|
||||
|
||||
```python
|
||||
...
|
||||
|
@ -269,6 +277,22 @@ class InferenceEngine:
|
|||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# check if the provided drafter model is compatible with GLIDE structure
|
||||
# when `use_glide_drafter` is set to True
|
||||
if (
|
||||
use_glide_drafter
|
||||
and hasattr(drafter_model, "model")
|
||||
and hasattr(drafter_model.model, "layers")
|
||||
and hasattr(drafter_model.model.layers[0], "cross_attn")
|
||||
):
|
||||
self.use_glide = use_glide_drafter
|
||||
elif use_glide_drafter:
|
||||
self.logger.warning(
|
||||
f"`use_glide_drafter` is provided as {use_glide_drafter}, "
|
||||
f"but the provided drafter model is not compatible with GLIDE structure."
|
||||
f"Falling back to use the default drafter model (non-GLIDE)."
|
||||
)
|
||||
self.request_handler.set_spec_dec_mode(self.n_spec_tokens)
|
||||
# using speculative decoding for subsequent generations
|
||||
self.use_spec_dec = True
|
||||
|
@ -278,6 +302,7 @@ class InferenceEngine:
|
|||
self.request_handler.unset_spec_dec_mode()
|
||||
# set back to the maximum number of tokens to speculate
|
||||
self.n_spec_tokens = self.inference_config.max_n_spec_tokens
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def clear_spec_dec(self) -> None:
|
||||
|
@ -288,6 +313,7 @@ class InferenceEngine:
|
|||
self.drafter_model = None
|
||||
self.drafter = None
|
||||
torch.cuda.empty_cache()
|
||||
self.use_glide = False
|
||||
self.use_spec_dec = False
|
||||
|
||||
def steps_spec_dec(self) -> List[Sequence]:
|
||||
|
@ -304,6 +330,7 @@ class InferenceEngine:
|
|||
input_ids = batch.get_1D_inputs() # bsz 1 for drafter model
|
||||
|
||||
# 1. Prefill small model (Drafter) - fill past kv cache for drafter model
|
||||
# NOTE For glide drafter models, we won't actually apply glide during prefill stage
|
||||
drafter_out = self.drafter.speculate(input_ids, 1, None)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
|
@ -326,7 +353,21 @@ class InferenceEngine:
|
|||
assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."
|
||||
|
||||
# 3. Decoding - Drafter model speculates `n` tokens
|
||||
drafter_out = self.drafter.speculate(input_ids, self.n_spec_tokens, drafter_past_key_values)
|
||||
glide_input = None
|
||||
if self.use_glide:
|
||||
glide_input = GlideInput(
|
||||
batch.get_block_table_tensor(),
|
||||
self.k_cahce[-1], # use kv cahces of the last layer
|
||||
self.v_cache[-1],
|
||||
batch.get_sequence_lengths(),
|
||||
)
|
||||
|
||||
drafter_out = self.drafter.speculate(
|
||||
input_ids,
|
||||
self.n_spec_tokens,
|
||||
drafter_past_key_values,
|
||||
glide_input=glide_input,
|
||||
)
|
||||
next_token_ids_spec = drafter_out.next_tokens
|
||||
drafter_past_key_values = drafter_out.past_key_values
|
||||
drafter_spec_length = drafter_out.speculated_length
|
||||
|
@ -339,6 +380,8 @@ class InferenceEngine:
|
|||
already_allocated_kv_len = cur_length
|
||||
|
||||
# 4. Decoding - Main model verifies `n` tokens in parallel
|
||||
if drafter_spec_length < batch.num_tokens_to_verify:
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=drafter_spec_length)
|
||||
logits = self.model(batch, self.k_cahce, self.v_cache)
|
||||
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
|
||||
|
||||
|
@ -348,6 +391,7 @@ class InferenceEngine:
|
|||
|
||||
# revoke appended tokens for each Sequence in the current batch
|
||||
batch.revoke_batch_tokens(drafter_spec_length - n_matches) # revoke drafted tokens
|
||||
|
||||
# append the last correct token generated by the main model
|
||||
self.request_handler.append_next_tokens(next_tokens[n_matches].unsqueeze(0))
|
||||
|
||||
|
@ -355,6 +399,7 @@ class InferenceEngine:
|
|||
drafter_past_key_values = Drafter.trim_kv_cache(
|
||||
drafter_past_key_values, drafter_spec_length - n_matches - 1
|
||||
)
|
||||
|
||||
# prepare inputs for the next round of speculation
|
||||
n = 1 if n_matches < drafter_spec_length else 2
|
||||
input_ids = batch.get_1D_inputs_spec_dec(n)
|
||||
|
@ -364,6 +409,11 @@ class InferenceEngine:
|
|||
if len(finished_sequences) > 0:
|
||||
break
|
||||
|
||||
# Reset back the number of speculated tokens of the batch,
|
||||
# this is used to handle the last round of speculation, in which case the number of speculated tokens
|
||||
# by the drafter is less than the number of speculated tokens set to the engine.
|
||||
batch.set_use_spec_dec(num_tokens_to_verify=self.n_spec_tokens)
|
||||
|
||||
return finished_sequences
|
||||
|
||||
def generate(
|
||||
|
|
|
@ -0,0 +1,475 @@
|
|||
# This is modified from huggingface transformers
|
||||
# https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/models/llama/modeling_llama.py
|
||||
import warnings
|
||||
from types import MethodType
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
LlamaConfig,
|
||||
LlamaDecoderLayer,
|
||||
LlamaDynamicNTKScalingRotaryEmbedding,
|
||||
LlamaForCausalLM,
|
||||
LlamaLinearScalingRotaryEmbedding,
|
||||
LlamaMLP,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
LlamaRotaryEmbedding,
|
||||
)
|
||||
|
||||
from colossalai.inference.spec import GlideInput
|
||||
from colossalai.kernel.triton import flash_decoding_attention
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_single_rotary_pos_emb(q, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
return q_embed
|
||||
|
||||
|
||||
def glide_llama_causal_lm_forward(
|
||||
self: LlamaForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
glide_input: Optional[GlideInput] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||
```"""
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
glide_input=glide_input,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
logits = logits.float()
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=None,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
def glide_llama_model_forward(
|
||||
self: LlamaModel,
|
||||
input_ids: torch.LongTensor = None,
|
||||
glide_input: GlideInput = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length = inputs_embeds.shape[:2]
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
past_key_values_length = 0
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
elif self._use_sdpa and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
||||
)
|
||||
|
||||
# embed positions
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
# GlideLlamaDecoderLayer
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
glide_input=glide_input,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class GlideLlamaConfig(LlamaConfig):
|
||||
"""Configuration class with specific arguments used by GLIDE llama model as a drafter"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
large_hidden_size=4096,
|
||||
large_num_attention_heads=32,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.large_hidden_size = large_hidden_size
|
||||
self.large_num_attention_heads = large_num_attention_heads
|
||||
|
||||
|
||||
class LlamaCrossAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: GlideLlamaConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||
f" and `num_heads`: {self.num_heads})."
|
||||
)
|
||||
|
||||
# large model (verifier) configs
|
||||
self.large_hidden_size = config.large_hidden_size
|
||||
self.large_num_heads = config.large_num_attention_heads
|
||||
self.large_head_dim = self.large_hidden_size // self.large_num_heads
|
||||
|
||||
self.q_proj = nn.Linear(self.hidden_size, self.large_num_heads * self.large_head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(self.large_num_heads * self.large_head_dim, self.hidden_size, bias=False)
|
||||
self._init_rope()
|
||||
|
||||
def _init_rope(self):
|
||||
if self.config.rope_scaling is None:
|
||||
self.rotary_emb = LlamaRotaryEmbedding(
|
||||
self.large_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
)
|
||||
else:
|
||||
scaling_type = self.config.rope_scaling["type"]
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if scaling_type == "linear":
|
||||
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
||||
self.large_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
|
||||
self.large_head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
scaling_factor=scaling_factor,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
glide_input: GlideInput = None, # Used for glimpsing main model's KV caches
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Optional[torch.Tensor]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
block_tables = glide_input.block_tables
|
||||
large_k_cache = glide_input.large_k_cache
|
||||
large_v_cache = glide_input.large_v_cache
|
||||
sequence_lengths = glide_input.sequence_lengths
|
||||
cache_block_size = large_k_cache.size(-2)
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
kv_seq_len = sequence_lengths.max().item()
|
||||
|
||||
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
|
||||
|
||||
# for RoPE
|
||||
cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
|
||||
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
|
||||
query_states = query_states.transpose(1, 2)
|
||||
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
|
||||
|
||||
attn_output = flash_decoding_attention(
|
||||
q=query_states,
|
||||
k_cache=large_k_cache,
|
||||
v_cache=large_v_cache,
|
||||
kv_seq_len=sequence_lengths,
|
||||
block_tables=block_tables,
|
||||
block_size=cache_block_size,
|
||||
max_seq_len_in_batch=kv_seq_len,
|
||||
) # attn_output: [bsz * q_len, num_heads * head_dim]
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.large_hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
# A class to be used to replace LlamaDecoderLayer in a Llama Model as Drafter in speculative decoding.
|
||||
# Refer to GLIDE with a CAPE https://arxiv.org/pdf/2402.02082.pdf
|
||||
class GlideLlamaDecoderLayer(nn.Module):
|
||||
def __init__(self, config: GlideLlamaConfig, layer_idx: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
|
||||
self.cross_attn = LlamaCrossAttention(config=config)
|
||||
self.mlp = LlamaMLP(config)
|
||||
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: LlamaDecoderLayer, *args, **kwargs) -> "GlideLlamaDecoderLayer":
|
||||
"""Build a GlideLlamaDecoderLayer from a native LlamaDecoderLayer"""
|
||||
config: LlamaConfig = module.mlp.config # XXX
|
||||
layer_idx = module.self_attn.layer_idx
|
||||
glide_config = GlideLlamaConfig(**config.to_dict())
|
||||
glide_decoder_layer = GlideLlamaDecoderLayer(glide_config, layer_idx=layer_idx)
|
||||
|
||||
return glide_decoder_layer
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
glide_input: GlideInput = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
|
||||
query_sequence_length, key_sequence_length)` if default attention is used.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
if "padding_mask" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||
)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
curr_q_len = hidden_states.size(1)
|
||||
# Cross attention
|
||||
if glide_input is None or not glide_input.glimpse_ready:
|
||||
warnings.warn(
|
||||
"Data used for glimpsing the past KV caches of the main model (verifier) is not complete. "
|
||||
"Fall back to normal decoder layer modeling (drafter). "
|
||||
"This might lead to incorrect results when using the Glide Models for speculative decoding."
|
||||
)
|
||||
elif curr_q_len == 1:
|
||||
# Notice that we skip prefill stage
|
||||
# always use the output of the main model as the inputs for the next round of speculation
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.cross_attn(
|
||||
hidden_states=hidden_states,
|
||||
glide_input=glide_input,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=True,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class GlideLlamaForCausalLM(LlamaForCausalLM):
|
||||
def __init__(self, config: GlideLlamaConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
bound_method = MethodType(glide_llama_causal_lm_forward, self)
|
||||
setattr(self, "forward", bound_method)
|
||||
bound_method = MethodType(glide_llama_model_forward, self.model)
|
||||
model = getattr(self, "model")
|
||||
setattr(model, "forward", bound_method)
|
||||
replaced_layers = nn.ModuleList(
|
||||
[GlideLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
setattr(model, "layers", replaced_layers)
|
|
@ -1,7 +1,9 @@
|
|||
from .glide_llama import GlideLlamaModelPolicy
|
||||
from .nopadding_llama import NoPaddingLlamaModelInferPolicy
|
||||
|
||||
model_policy_map = {
|
||||
"nopadding_llama": NoPaddingLlamaModelInferPolicy,
|
||||
"glide_llama": GlideLlamaModelPolicy,
|
||||
}
|
||||
|
||||
__all__ = ["NoPaddingLlamaModelInferPolicy", "model_polic_map"]
|
||||
__all__ = ["NoPaddingLlamaModelInferPolicy", "GlideLlamaModelPolicy", "model_polic_map"]
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
|
||||
|
||||
from colossalai.inference.modeling.models.glide_llama import (
|
||||
GlideLlamaDecoderLayer,
|
||||
glide_llama_causal_lm_forward,
|
||||
glide_llama_model_forward,
|
||||
)
|
||||
from colossalai.inference.utils import init_to_get_rotary
|
||||
from colossalai.shardformer.policies.base_policy import SubModuleReplacementDescription
|
||||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||
|
||||
|
||||
class GlideLlamaModelPolicy(LlamaForCausalLMPolicy):
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
|
||||
num_layers = self.model.config.num_hidden_layers
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix=f"layers[{i}]",
|
||||
target_module=GlideLlamaDecoderLayer,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": glide_llama_model_forward},
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
)
|
||||
self.append_or_create_method_replacement(
|
||||
description={"forward": glide_llama_causal_lm_forward},
|
||||
policy=policy,
|
||||
target_key=LlamaForCausalLM,
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
for layer in self.model.model.layers:
|
||||
init_to_get_rotary(layer.cross_attn)
|
||||
return self.model
|
|
@ -1,4 +1,4 @@
|
|||
from .drafter import Drafter
|
||||
from .struct import DrafterOutput
|
||||
from .struct import DrafterOutput, GlideInput
|
||||
|
||||
__all__ = ["Drafter", "DrafterOutput"]
|
||||
__all__ = ["Drafter", "DrafterOutput", "GlideInput"]
|
||||
|
|
|
@ -6,7 +6,7 @@ from transformers import PreTrainedTokenizer
|
|||
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .struct import DrafterOutput
|
||||
from .struct import DrafterOutput, GlideInput
|
||||
|
||||
|
||||
class Drafter:
|
||||
|
@ -66,6 +66,7 @@ class Drafter:
|
|||
input_ids: torch.Tensor,
|
||||
n_spec_tokens: int,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
glide_input: Optional[GlideInput] = None,
|
||||
) -> DrafterOutput:
|
||||
"""Generate n_spec_tokens tokens using the drafter model.
|
||||
|
||||
|
@ -73,6 +74,8 @@ class Drafter:
|
|||
input_ids (torch.Tensor): Input token ids.
|
||||
n_spec_tokens (int): Number of tokens to speculate.
|
||||
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
|
||||
glide_input (Optional[GlideInput]): The packed input for glimpsing kv caches of the main model,
|
||||
when using the glide model as a drafter.
|
||||
"""
|
||||
assert n_spec_tokens >= 1, f"Invalid number {n_spec_tokens} to speculate"
|
||||
|
||||
|
@ -83,13 +86,16 @@ class Drafter:
|
|||
logits = []
|
||||
token_ids = []
|
||||
|
||||
kwargs = {"return_dict": True, "use_cache": True}
|
||||
if glide_input:
|
||||
# required only when using glide model
|
||||
kwargs["glide_input"] = glide_input
|
||||
|
||||
for _ in range(n_spec_tokens):
|
||||
outputs = self._drafter_model(
|
||||
input_ids,
|
||||
return_dict=True,
|
||||
use_cache=True,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
# update past key values
|
||||
kwargs["past_key_values"] = past_key_values
|
||||
|
||||
outputs = self._drafter_model(input_ids, **kwargs)
|
||||
next_token_logits = outputs.logits[:, -1, :]
|
||||
|
||||
# NOTE Only use greedy search for speculating.
|
||||
|
@ -100,12 +106,12 @@ class Drafter:
|
|||
logits.append(next_token_logits)
|
||||
token_ids.append(next_token_ids)
|
||||
if next_token_ids.item() == self._tokenizer.eos_token_id:
|
||||
# TODO support bsz > 1
|
||||
# TODO(yuanheng-zhao) support bsz > 1
|
||||
break
|
||||
input_ids = next_token_ids[:, None]
|
||||
past_key_values = outputs.past_key_values
|
||||
|
||||
speculated_length = len(token_ids) # TODO For now, only support bsz 1
|
||||
speculated_length = len(token_ids) # For now, only support bsz 1
|
||||
logits = torch.concat(logits, dim=0)
|
||||
token_ids = torch.concat(token_ids, dim=-1)
|
||||
|
||||
|
|
|
@ -27,3 +27,29 @@ class DrafterOutput:
|
|||
if self.past_key_values is not None:
|
||||
assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
|
||||
assert all([isinstance(past_key_value, tuple) for past_key_value in self.past_key_values])
|
||||
|
||||
|
||||
@dataclass
|
||||
class GlideInput:
|
||||
"""Dataclass for Glide Models (e.g. `colossalai/inference/modeling/models/glide_llama.py`).
|
||||
Used for pack data that will be used during glimpsing KV Caches of the main model.
|
||||
|
||||
Args:
|
||||
block_tables (torch.Tensor): [num_seqs, max_blocks_per_seq] The block table of KV Caches.
|
||||
large_k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_size]
|
||||
Blocked key cache of the main model
|
||||
large_v_cache (torch.Tensor): Blocked value cache of the main model. It has the same shape as k cache.
|
||||
sequence_lengths (torch.Tensor): [num_seqs] Sequence lengths of the current batch.
|
||||
"""
|
||||
|
||||
block_tables: torch.Tensor = None
|
||||
large_k_cache: torch.Tensor = None
|
||||
large_v_cache: torch.Tensor = None
|
||||
sequence_lengths: torch.Tensor = None
|
||||
|
||||
@property
|
||||
def glimpse_ready(self):
|
||||
return all(
|
||||
attr is not None
|
||||
for attr in [self.block_tables, self.large_k_cache, self.large_v_cache, self.sequence_lengths]
|
||||
)
|
||||
|
|
|
@ -2,18 +2,16 @@ import pytest
|
|||
import torch
|
||||
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
|
||||
|
||||
import colossalai
|
||||
from colossalai.inference.config import GenerationConfig, InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
|
||||
from colossalai.inference.spec.drafter import Drafter
|
||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
NUM_LAYERS = 2
|
||||
NUM_LAYERS = 1
|
||||
MAX_LEN = 100
|
||||
SPEC_NUM = 5
|
||||
|
||||
|
||||
@pytest.mark.parametrize("spec_num", [5])
|
||||
@pytest.mark.parametrize("spec_num", [SPEC_NUM])
|
||||
def test_drafter(spec_num: int):
|
||||
torch.manual_seed(123)
|
||||
|
||||
|
@ -41,68 +39,33 @@ def test_drafter(spec_num: int):
|
|||
assert trimmed_past_key_values[0][0].size(2) == past_kv_length - reject_num
|
||||
|
||||
|
||||
def check_sd():
|
||||
torch.manual_seed(123)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
# Dummy configs for testing
|
||||
toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
|
||||
toy_config.pad_token_id = tokenizer.eos_token_id
|
||||
drafter_model = LlamaForCausalLM(toy_config)
|
||||
drafter_model = drafter_model.eval().cuda()
|
||||
large_config = LlamaConfig(
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=8,
|
||||
num_key_value_heads=32,
|
||||
max_position_embeddings=2048,
|
||||
)
|
||||
large_config.pad_token_id = tokenizer.eos_token_id
|
||||
main_model = LlamaForCausalLM(large_config)
|
||||
|
||||
inference_config = InferenceConfig(
|
||||
dtype="fp16",
|
||||
micro_batch_size=1,
|
||||
max_batch_size=1,
|
||||
max_input_len=128,
|
||||
max_output_len=128,
|
||||
prefill_ratio=1.2,
|
||||
block_size=16,
|
||||
)
|
||||
engine = InferenceEngine(main_model, tokenizer, inference_config)
|
||||
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||
|
||||
dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
|
||||
generation_config = GenerationConfig(
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
max_length=MAX_LEN,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
out, out_token_ids = engine.generate(
|
||||
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
|
||||
)
|
||||
engine.disable_spec_dec()
|
||||
engine.clear_spec_dec()
|
||||
|
||||
assert not engine.use_spec_dec
|
||||
assert engine.drafter is None and engine.drafter_model is None
|
||||
|
||||
assert len(out) == 1
|
||||
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == MAX_LEN
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_sd()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_spec_dec():
|
||||
spawn(run_dist, nprocs=1)
|
||||
spec_num = SPEC_NUM
|
||||
device = get_current_device()
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Dummy config for Glide Model
|
||||
glide_config = GlideLlamaConfig(
|
||||
intermediate_size=8192,
|
||||
large_hidden_size=4096,
|
||||
large_num_attention_heads=32,
|
||||
num_hidden_layers=NUM_LAYERS,
|
||||
)
|
||||
drafter_model = GlideLlamaForCausalLM(glide_config)
|
||||
|
||||
assert hasattr(drafter_model, "model")
|
||||
assert hasattr(drafter_model.model, "layers")
|
||||
for _, layer in enumerate(drafter_model.model.layers):
|
||||
assert hasattr(layer, "cross_attn")
|
||||
|
||||
# Init the Drafter by providing the sharded drafter model
|
||||
drafter = Drafter(drafter_model, tokenizer, device=device, dtype=torch.float16)
|
||||
|
||||
input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
|
||||
out = drafter.speculate(input_ids, spec_num, past_key_values=None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_drafter(spec_num=5)
|
||||
test_drafter(spec_num=SPEC_NUM)
|
||||
test_spec_dec()
|
||||
|
|
|
@ -9,6 +9,7 @@ import colossalai
|
|||
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.inference.flash_decoding_utils import FDIntermTensors
|
||||
from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
|
@ -80,9 +81,81 @@ def check_output_consistency(prompt_template):
|
|||
FDIntermTensors._instances = {}
|
||||
|
||||
|
||||
@parameterize("num_layers", [1])
|
||||
@parameterize("max_length", [100])
|
||||
def check_spec_dec(num_layers, max_length):
|
||||
torch.manual_seed(123)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
# Dummy configs for testing
|
||||
toy_config = LlamaConfig(num_hidden_layers=num_layers)
|
||||
toy_config.pad_token_id = tokenizer.eos_token_id
|
||||
drafter_model = LlamaForCausalLM(toy_config)
|
||||
drafter_model = drafter_model.eval().cuda()
|
||||
large_config = LlamaConfig(
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=8,
|
||||
num_key_value_heads=32,
|
||||
max_position_embeddings=2048,
|
||||
)
|
||||
large_config.pad_token_id = tokenizer.eos_token_id
|
||||
main_model = LlamaForCausalLM(large_config)
|
||||
|
||||
inference_config = InferenceConfig(
|
||||
dtype="fp16",
|
||||
micro_batch_size=1,
|
||||
max_batch_size=1,
|
||||
max_input_len=128,
|
||||
max_output_len=128,
|
||||
prefill_ratio=1.2,
|
||||
block_size=16,
|
||||
)
|
||||
engine = InferenceEngine(main_model, tokenizer, inference_config)
|
||||
engine.enable_spec_dec(drafter_model, n_spec_tokens=5)
|
||||
|
||||
dummy_inputs = torch.randint(low=5, high=1000, size=(1, 10), dtype=torch.long, device="cuda")
|
||||
generation_config = GenerationConfig(
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
max_length=max_length,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
out, out_token_ids = engine.generate(
|
||||
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
|
||||
)
|
||||
engine.disable_spec_dec()
|
||||
engine.clear_spec_dec()
|
||||
|
||||
assert not engine.use_spec_dec
|
||||
assert engine.drafter is None and engine.drafter_model is None
|
||||
|
||||
assert len(out) == 1
|
||||
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
|
||||
|
||||
# test GLIDE model
|
||||
glide_config = GlideLlamaConfig(
|
||||
intermediate_size=8192,
|
||||
large_hidden_size=4096,
|
||||
large_num_attention_heads=32,
|
||||
num_hidden_layers=num_layers,
|
||||
)
|
||||
glide_model = GlideLlamaForCausalLM(glide_config)
|
||||
engine.enable_spec_dec(glide_model, use_glide_drafter=True)
|
||||
|
||||
out, out_token_ids = engine.generate(
|
||||
prompts_token_ids=dummy_inputs, generation_config=generation_config, return_token_ids=True
|
||||
)
|
||||
engine.clear_spec_dec()
|
||||
|
||||
assert len(out) == 1
|
||||
assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
check_output_consistency()
|
||||
check_spec_dec()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
|
|
Loading…
Reference in New Issue