@@ -1,11 +1,7 @@
import enum
from dataclasses import dataclass
from typing import Any, List, Tuple, Union
import torch
from ordered_set import OrderedSet
from colossalai.inference.flash_decoding_utils import FDIntermTensors
from colossalai.logging import get_dist_logger

logger = get_dist_logger(__name__)
@@ -170,242 +166,6 @@ class Sequence:
)

@dataclass
class BatchInfo:
    """
    Information to be passed and used for a batch of sequences.
    """

    max_batch_size: int
    kv_max_split_num: int
    num_heads: int
    head_dim: int
    sequences_set: OrderedSet[Sequence] = None
    is_prompts: bool = True
    device: torch.device = None
    dtype: torch.dtype = None
    fd_inter_tensor: FDIntermTensors = None

    def __post_init__(self):
        if self.device is None:
            self.device = torch.cuda.current_device()
        if self.sequences_set is None:
            self.sequences_set = OrderedSet()
        if self.fd_inter_tensor is None:
            self.fd_inter_tensor = FDIntermTensors()

    def init_fd_tensors(self):
        if not self.fd_inter_tensor.is_initialized:
            self.fd_inter_tensor.initialize(
                max_batch_size=self.max_batch_size,
                num_attn_heads=self.num_heads,
                kv_max_split_num=self.kv_max_split_num,
                head_dim=self.head_dim,
                dtype=self.dtype,
                device=self.device,
            )
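
    # Illustrative call pattern (not part of the original code): the flash-decoding
    # intermediate tensors are allocated lazily, so repeated calls are no-ops:
    #
    #   >>> batch.init_fd_tensors()  # allocates buffers on the first call
    #   >>> batch.init_fd_tensors()  # already initialized, returns immediately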

    def get_block_table_tensor(self) -> torch.Tensor:
        tensor_list = []

        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

        for seq in self.sequences_set:
            block_table = seq.block_table
            assert (
                block_table is not None
            ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
            tensor_list.append(seq.block_table)

        return torch.stack(tensor_list)
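
    # Shape sketch (hypothetical block ids, not from the original code): with two
    # sequences whose block tables are [0, 1, -1] and [2, 3, 4], the stacked result
    # has shape [batch_size, max_blocks_per_sequence]:
    #
    #   >>> batch.get_block_table_tensor()
    #   tensor([[ 0,  1, -1],
    #           [ 2,  3,  4]])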

    def clear_batch(self) -> None:
        """
        Clear the sequence set and block table if we need to abort this batch.
        Prefill: clear the sequence set and move the sequences to the running batch (external).
        Decoding: mark unfinished sequences as aborted.
        """
        if self.is_prompts:
            self.sequences_set.clear()
        else:
            for seq in self.sequences_set:
                seq.mark_aborted()
                if seq.check_finish():
                    seq.mark_finished()
            self.sequences_set.clear()

    def filter_batch(self) -> List["Sequence"]:
        """
        Remove completed sequences from the batch.

        Returns:
            List["Sequence"]: The list of finished sequences.
        """
        finish_seqs = []
        for seq in self.sequences_set:
            if seq.check_finish():
                finish_seqs.append(seq)
        for finish_seq in finish_seqs:
            self.sequences_set.discard(finish_seq)
        return finish_seqs

    def abort_seq(self, seq: "Sequence") -> "Sequence":
        """
        Remove a sequence from the batch.
        """
        if not seq.check_finish():
            seq.status = RequestStatus.ABORTED
        self.sequences_set.discard(seq)
        return seq

    def add_seqs(self, seqs: Union[Sequence, List[Sequence]]) -> None:
        """
        Add new sequences to the batch.

        Args:
            seqs (Union[Sequence, List[Sequence]]): The new sequence(s) to add.
        """
        # convert a single sequence to a list
        if isinstance(seqs, Sequence):
            seqs = [seqs]

        for seq in seqs:
            if seq in self.sequences_set:
                logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.")
                continue
            self.sequences_set.add(seq)
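
    # Example (sketch): a single Sequence and a list of Sequences are both accepted;
    # re-adding a sequence only logs a warning and is otherwise skipped:
    #
    #   >>> batch.add_seqs(seq)          # single sequence
    #   >>> batch.add_seqs([seq, seq2])  # list; `seq` triggers the warning above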

    def del_seq(self, seq: Sequence) -> None:
        """
        Delete a sequence from the batch.
        """
        self.sequences_set.discard(seq)

    @property
    def is_empty(self) -> bool:
        """
        Check whether sequences_set is empty.
        """
        return not self.sequences_set

    def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None:
        """
        Add an output token for each sequence in the batch.

        Args:
            tokens (Union[List[int], List[List[int]], torch.Tensor]): A batch of tokens.
        """
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()

        assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size."

        for seq, token in zip(self.sequences_set, tokens):
            if not isinstance(token, list):
                if not isinstance(token, int):
                    raise TypeError(f"The token type must be List[int] or int, but got {type(token)}.")
                token = [token]
            seq.output_token_id += token
            seq.check_finish()
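
    # Equivalent call forms (sketch, assuming a batch of two running sequences):
    # each entry is appended to the matching sequence's output_token_id, so for
    # one decoding step these all do the same thing:
    #
    #   >>> batch.update_batch_tokens([3, 5])
    #   >>> batch.update_batch_tokens([[3], [5]])
    #   >>> batch.update_batch_tokens(torch.tensor([3, 5]))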

    def get_batch_size(self) -> int:
        """
        Get batch_size of this batch.
        """
        return len(self.sequences_set)

    def get_batch_inputs(self) -> torch.Tensor:
        """
        Get batch inputs for forward inference computation.
        """
        input_list = []

        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

        for seq in self.sequences_set:
            if self.is_prompts:
                if seq.output_len > 0:
                    input_list.append(seq.input_token_id + seq.output_token_id)
                else:
                    input_list.append(seq.input_token_id)
            else:
                input_list.append([seq.output_token_id[-1]])

        max_seq_len = max(len(sub_list) for sub_list in input_list)

        # We assume that all the padding ids in a batch are the same at present.
        return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
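
    # Shape sketch (hypothetical token ids; device formatting omitted): prompts
    # [1, 2, 3] and [4, 5] with pad_token_id=0 are left-padded into one 2D tensor:
    #
    #   >>> batch.get_batch_inputs()
    #   tensor([[1, 2, 3],
    #           [0, 4, 5]], dtype=torch.int32)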

    def get_1D_inputs(self) -> torch.Tensor:
        """
        Flatten the input tokens.
        """
        input_list = []

        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

        for seq in self.sequences_set:
            if self.is_prompts:
                input_list.extend(seq.input_token_id)
            else:
                input_list.append(seq.output_token_id[-1])

        return torch.tensor(input_list, dtype=torch.long, device=self.device)
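
    # Sketch (same hypothetical prompts as above): prefill flattens every prompt
    # token into one 1D tensor, while decoding keeps only the last generated token
    # of each sequence:
    #
    #   prefill:  tensor([1, 2, 3, 4, 5])
    #   decoding: tensor([<last token of seq 0>, <last token of seq 1>])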

    def get_sequence_lengths(self) -> torch.Tensor:
        """
        Get the sentence_len (prompt length plus generated length) of each sequence in this batch.
        """
        len_list = []

        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

        for seq in self.sequences_set:
            len_list.append(seq.sentence_len)

        return torch.tensor(len_list, dtype=torch.int, device=self.device)

    def get_attn_mask(self) -> torch.Tensor:
        """
        Generate and return the attention mask.
        """
        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."

        past_values = []
        # We assume that all the padding ids in a batch are the same at present.
        padding_id = self.sequences_set[0].pad_token_id

        for seq in self.sequences_set:
            past_values.append(seq.input_token_id + seq.output_token_id)

        max_seq_len = max(len(sub_list) for sub_list in past_values)
        attn_mask = _make_tensor_with_pad(past_values, max_seq_len, padding_id, dtype=torch.int, device=self.device)

        return attn_mask.ne(padding_id).long()
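
    # Sketch (hypothetical ids, pad_token_id=0): the mask is 1 for real tokens and
    # 0 for left padding; note that a genuine token equal to the padding id would
    # also be masked out:
    #
    #   tokens: [[1, 2, 3],        mask: [[1, 1, 1],
    #            [0, 4, 5]]               [0, 1, 1]]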

    def __repr__(self) -> str:
        return f"(sequences_set={self.sequences_set}, is_prompts={self.is_prompts})"
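

# A minimal lifecycle sketch (not part of the original module). It assumes the
# Sequence objects were built elsewhere with valid block tables; the sizes below
# are placeholder values, and an explicit CPU device avoids the CUDA default.
def _demo_batch_lifecycle(seqs: List[Sequence]) -> None:
    batch = BatchInfo(
        max_batch_size=8,     # placeholder capacity
        kv_max_split_num=16,  # placeholder flash-decoding split count
        num_heads=32,         # placeholder attention head count
        head_dim=128,         # placeholder head dimension
        device=torch.device("cpu"),
        dtype=torch.float16,
    )
    batch.add_seqs(seqs)                    # accepts a Sequence or a list of them
    inputs = batch.get_1D_inputs()          # flattened prompt tokens on batch.device
    lengths = batch.get_sequence_lengths()  # [batch_size], on batch.device
    mask = batch.get_attn_mask()            # 1 for real tokens, 0 for padding
    finished = batch.filter_batch()         # pop already-finished sequences
    logger.info(
        f"inputs={tuple(inputs.shape)}, lengths={lengths.tolist()}, "
        f"mask={tuple(mask.shape)}, finished={len(finished)}"
    )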


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return [pad] * (max_len - len(x)) + x


def _make_tensor_with_pad(
    x: Union[List[List[int]], List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cuda",
    pin_memory: bool = False,
):
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu")
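
# A minimal sketch of the helpers above (not part of the original code):
# sequences shorter than max_len are padded on the LEFT with `pad`, so the
# newest tokens stay right-aligned:
#
#   >>> _make_tensor_with_pad([[1, 2, 3], [4]], max_len=3, pad=0, dtype=torch.int, device="cpu")
#   tensor([[1, 2, 3],
#           [0, 0, 4]], dtype=torch.int32)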