mirror of https://github.com/hpcaitech/ColossalAI
57 lines
2.1 KiB
Python
57 lines
2.1 KiB
Python
from dataclasses import dataclass
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
|
|
|
|
@dataclass
class DrafterOutput:
    """
    Dataclass for drafter model outputs.

    Args:
        speculated_length (int): Speculated length of the output sequence.
            It is always less than or equal to spec_num during drafter's speculation process.
            The field defaults to None, but construction fails unless a non-negative
            value is supplied.
        logits (torch.FloatTensor): Logits of the output sequence
        next_tokens (torch.Tensor): Next token ids
        past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Past key values of the output sequence.
            When given, it must be a tuple whose elements are all tuples.

    Raises:
        AssertionError: If `speculated_length` is None or negative, or if
            `past_key_values` is given but is not a tuple of tuples.
    """

    speculated_length: Optional[int] = None
    logits: Optional[torch.FloatTensor] = None
    next_tokens: Optional[torch.Tensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

    def __post_init__(self):
        """Validate the fields right after the generated `__init__` has run."""
        # Explicit raises (instead of `assert` statements) keep the validation active
        # when Python runs with `-O`. AssertionError is kept so that the exception
        # type seen by existing callers does not change.
        if self.speculated_length is None or not self.speculated_length >= 0:
            raise AssertionError("speculated_length must be a non-negative integer")

        if self.past_key_values is None:
            return

        if not isinstance(self.past_key_values, tuple):
            raise AssertionError("Past key values should be a tuple")
        # Generator form short-circuits on the first offending element.
        if not all(isinstance(entry, tuple) for entry in self.past_key_values):
            raise AssertionError("Each element of past key values should be a tuple")
|
|
|
|
|
|
@dataclass
class GlideInput:
    """Input bundle for Glide models (e.g. `colossalai/inference/modeling/models/glide_llama.py`).

    Packs the data that is used while glimpsing the KV caches of the main model.

    Args:
        block_tables (torch.Tensor): [num_seqs, max_blocks_per_seq] The block table of KV Caches.
        large_k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_size]
            Blocked key cache of the main model
        large_v_cache (torch.Tensor): Blocked value cache of the main model. It has the same shape as k cache.
        sequence_lengths (torch.Tensor): [num_seqs] Sequence lengths of the current batch.
        n_spec_tokens (int): Defaults to 5. Presumably the number of speculated tokens;
            not described in the original documentation - confirm against the caller.
    """

    block_tables: torch.Tensor = None
    large_k_cache: torch.Tensor = None
    large_v_cache: torch.Tensor = None
    sequence_lengths: torch.Tensor = None
    n_spec_tokens: int = 5

    @property
    def glimpse_ready(self) -> bool:
        """Whether the block tables, both caches and the sequence lengths are all set (not None)."""
        needed = (
            self.block_tables,
            self.large_k_cache,
            self.large_v_cache,
            self.sequence_lengths,
        )
        # Identity check only: comparing tensors with `==` or truthiness would be
        # element-wise / ambiguous, so each field is tested with `is None`.
        for value in needed:
            if value is None:
                return False
        return True
|