[inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support different lengths generating

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
pull/4905/head
Jianghai 1 year ago committed by GitHub
parent 8aed02b957
commit e0757c31fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,346 @@
import collections
from dataclasses import dataclass
from typing import Dict, List , Tuple
import numpy as np
import torch
from colossalai.inference.tensor_parallel import MemoryManager
# make batch infer state an attr of InferBatch
class InferSamplingParams:
def __init__(
self,
do_sample: bool = False,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
vocab_size: int = -1,
) -> None:
self.do_sample = do_sample
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
if self.top_k == -1:
self.top_k = vocab_size
return
@dataclass
class InferBatch:
batch_id: int
requests: List
requests_idx_mapping: Dict[int, int]
input_ids: torch.Tensor
all_input_ids: List[List[int]]
input_lengths: List[int]
out_token_id_counts: List
sampling_param_list: List[InferSamplingParams]
nopad_total_token_num: int
nopad_max_len_in_batch: int
nopad_b_loc: torch.Tensor
nopad_b_start_loc: torch.Tensor
nopad_b_seq_len: torch.Tensor
cache_manager: MemoryManager
max_total_len: int
@classmethod
@torch.no_grad()
def init_batch(
cls,
batch_id,
requests,
dtype: torch.dtype,
device: torch.device,
cache_manager: MemoryManager,
vocab_size: int,
max_total_len: int,
) -> 'InferBatch':
input_lengths = []
all_input_ids = []
requests_idx_mapping = {}
out_token_id_counts = []
sampling_param_list = []
nopad_total_token_num = 0
nopad_max_len_in_batch = 0
nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda")
# to avoid memory leak , we pre-allocate 12 more space for each batch.
nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda")
for i, r in enumerate(requests):
# request id -> idx in list mapping
requests_idx_mapping[r["request_id"]] = i
tokenized_input = r["input_id"]
input_length = len(tokenized_input)
input_lengths.append(input_length)
all_input_ids.append(tokenized_input)
out_token_id_counts.append(collections.defaultdict(int))
# postprocessor
sampling_param = r["sampling_param"]
sampling_param["vocab_size"] = vocab_size
sampling_param_list.append(InferSamplingParams(**sampling_param))
nopad_total_token_num += input_length
nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length)
nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda")
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
if len(requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
else:
input_ids = all_input_ids[0]
# Create tensors on device
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
return cls(
batch_id=batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
nopad_total_token_num=nopad_total_token_num,
nopad_max_len_in_batch=nopad_max_len_in_batch,
nopad_b_loc=nopad_b_loc,
nopad_b_start_loc=nopad_b_start_loc,
nopad_b_seq_len=nopad_b_seq_len,
out_token_id_counts=out_token_id_counts,
sampling_param_list=sampling_param_list,
cache_manager=cache_manager,
max_total_len=max_total_len,
)
@torch.no_grad()
def free_self(self) -> None:
"""
Free the memory of the InferBatch itself
"""
remove_index = []
for idx in range(len(self)):
remove_index.append(
self.nopad_b_loc[
idx,
(self.nopad_max_len_in_batch - 1)
- (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
]
)
remove_index = torch.cat(remove_index, dim=-1)
self.cache_manager.free(remove_index)
@torch.no_grad()
def filter(self, request_ids: List[int]) -> 'InferBatch':
"""
Filter finished batch and return a new InferBatch with left ones.
"""
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
requests_idx_mapping = {}
indices = []
requests = []
all_input_ids = []
input_lengths = []
nopad_total_token_num = 0
nopad_max_len_in_batch = 0
nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda")
nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
left_idx = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
left_idx.append(idx)
left_idx_set = set(left_idx)
remove_index = []
for idx in range(len(self)):
if idx not in left_idx_set:
remove_index.append(
self.nopad_b_loc[
idx,
(self.nopad_max_len_in_batch - 1)
- (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
]
)
remove_index = torch.cat(remove_index, dim=-1)
self.cache_manager.free(remove_index)
nopad_max_len_in_batch = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
nopad_b_seq_len[:] = self.nopad_b_seq_len[indices]
nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item()
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
nopad_total_token_num = torch.sum(nopad_b_seq_len).item()
nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[
indices,
(self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1),
]
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
requests.append(self.requests[idx])
all_input_ids.append(self.all_input_ids[idx])
input_lengths.append(self.input_lengths[idx])
input_ids = self.input_ids[indices]
return InferBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
nopad_total_token_num=nopad_total_token_num,
nopad_max_len_in_batch=nopad_max_len_in_batch,
nopad_b_loc=nopad_b_loc,
nopad_b_start_loc=nopad_b_start_loc,
nopad_b_seq_len=nopad_b_seq_len,
out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices],
sampling_param_list=[self.sampling_param_list[_i] for _i in indices],
cache_manager=self.cache_manager,
max_total_len=self.max_total_len,
)
@classmethod
@torch.no_grad()
def merge(cls, batch1, batch2) -> 'InferBatch':
"""
Return megerd new InferBatch
"""
requests = batch1.requests + batch2.requests
requests_idx_mapping = {}
new_batch_size = len(batch1) + len(batch2)
input_ids = batch1.input_ids.new_empty(new_batch_size)
all_input_ids = []
input_lengths = []
out_token_id_counts = []
sampling_param_list = []
cumulative_batch_size = 0
nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num
nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch)
max_total_len = max(batch1.max_total_len, batch2.max_total_len)
nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda")
nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
nopad_start_loc_len_temp = 0
batches = [batch1, batch2]
for i, batch in enumerate(batches):
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
input_ids[start_index:end_index] = batch.input_ids
nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len
nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp
nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1]
nopad_b_loc[
start_index:end_index,
nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1,
] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1]
all_input_ids.extend(batch.all_input_ids)
input_lengths.extend(batch.input_lengths)
out_token_id_counts.extend(batch.out_token_id_counts)
sampling_param_list.extend(batch.sampling_param_list)
# Update
cumulative_batch_size += len(batch)
nopad_b_loc[:, nopad_max_len_in_batch - 1] = (
nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda")
)
return InferBatch(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
nopad_total_token_num=nopad_total_token_num,
nopad_max_len_in_batch=nopad_max_len_in_batch,
nopad_b_loc=nopad_b_loc,
nopad_b_start_loc=nopad_b_start_loc,
nopad_b_seq_len=nopad_b_seq_len,
out_token_id_counts=out_token_id_counts,
sampling_param_list=sampling_param_list,
cache_manager=batches[0].cache_manager,
max_total_len=max_total_len,
)
def __len__(self):
return len(self.requests)
def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
temperatures: List[float] = []
top_ps: List[float] = []
top_ks: List[int] = []
p_token_ids: List[int] = []
p_token_counts: List[int] = []
p_seq_len: List[int] = [
0,
]
p_max_len_in_batch: int = 0
for i, id_to_count in enumerate(self.out_token_id_counts):
sample_param = self.sampling_param_list[i]
presence_penalties.append(sample_param.presence_penalty)
frequency_penalties.append(sample_param.frequency_penalty)
temperatures.append(sample_param.temperature)
top_ps.append(sample_param.top_p)
top_ks.append(sample_param.top_k)
for token_id, count in id_to_count.items():
p_token_ids.append(token_id)
p_token_counts.append(count)
p_seq_len.append(len(id_to_count))
p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count))
presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda")
frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda")
temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda")
top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda")
top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda")
p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda")
p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda")
p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
return (
presence_penalties,
frequency_penalties,
temperatures,
top_ps,
top_ks,
p_token_ids,
p_token_counts,
p_cumsum_seq_len,
p_max_len_in_batch,
)

@ -0,0 +1,149 @@
from typing import Dict, List, Tuple
from .sampling_params import SamplingParams
class Req:
def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
self.request_id = request_id
self.prompt_ids = prompt_ids
self.input_len = len(prompt_ids)
self.max_output_len = sample_params.max_new_tokens
self.sample_params = sample_params
self.output_ids = []
self.output_metadata_list = []
self.has_generate_finished = False
self.aborted = False
def to_rpc_obj(self):
return {
"request_id": self.request_id,
"input_id": self.prompt_ids,
"output_len": self.max_output_len,
"sampling_param": self.sample_params.to_dict(),
}
def to_req_detokenization_state(self):
out = ReqDetokenizationState(
self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos
)
if self.output_metadata_list:
out.gen_metadata.update(self.output_metadata_list[-1])
return out
def stop_sequences_matched(self):
# should we add stpp sequences to the sample params?
if self.sample_params.stop_sequences is not None:
for stop_token_ids in self.sample_params.stop_sequences:
stop_len = len(stop_token_ids)
if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
return True
return False
def __repr__(self):
return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, "
class ReqDetokenizationState:
def __init__(
self,
request_id: str,
prompt_ids: List[int],
max_output_len: int,
ignore_eos: bool,
) -> None:
self.request_id = request_id
self.prompt_ids = prompt_ids
self.output_ids = []
self.output_tokens = []
self.output_str = ""
self.sub_texts = []
self.current_sub_text = []
self.max_output_len = max_output_len
self.ignore_eos = ignore_eos
self.gen_metadata = {}
class Batch:
def __init__(self, batch_id, reqs: List[Req]):
self.batch_id = batch_id
self.reqs = reqs
self.id_to_reqs = {req.request_id: req for req in reqs}
def input_tokens(self):
batch_input_tokens = 0
for req in self.reqs:
batch_input_tokens += req.input_len
return batch_input_tokens
def calcu_max_tokens(self):
tokens = 0
for req in self.reqs:
tokens += req.input_len + req.max_output_len
return tokens
def calcu_used_tokens(self):
tokens = 0
for req in self.reqs:
tokens += req.input_len + len(req.output_ids)
return tokens
def mark_finished_req(self, eos_id):
has_new_finish = False
for req in self.reqs:
if req.stop_sequences_matched():
req.has_generate_finished = True
has_new_finish = True
if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False:
req.has_generate_finished = True
has_new_finish = True
if len(req.output_ids) >= req.max_output_len or req.aborted:
req.has_generate_finished = True
has_new_finish = True
return has_new_finish
def filter_finished(self):
"""
Filter finished requests from the batch, the finished ones will be removed from 'reqs'.
"""
# TODO: the logic of return should be defined here.
unfinished_req = []
for req in self.reqs:
if not req.has_generate_finished:
unfinished_req.append(req)
self.reqs = unfinished_req
self.id_to_reqs = {req.request_id: req for req in self.reqs}
def is_clear(self):
return len(self.reqs) == 0
def merge(self, mini_batch):
for _req in mini_batch.reqs:
self.reqs.append(_req)
self.id_to_reqs = {req.request_id: req for req in self.reqs}
return
def __repr__(self):
return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, "
def __len__(self):
return len(self.reqs)
class BatchTokenIdOut:
def __init__(self):
self.reqs_infs: List[
Tuple[str, int, Dict, bool, bool]
] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
class BatchStrOut:
def __init__(self):
self.reqs_infs: List[
Tuple[str, str, Dict, bool, bool]
] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state]
class AbortReq:
def __init__(self, req_id):
self.req_id = req_id

@ -0,0 +1,71 @@
import uuid
from typing import List
import numpy as np
from .io_struct import Batch, Req
class ReqQueue:
def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None:
self.max_total_tokens = max_total_tokens
assert batch_max_tokens is not None
self.batch_max_tokens = batch_max_tokens
self.running_max_req_size = running_max_req_size
self.waiting_req_list: List[Req] = waiting_req_list
def append(self, req):
self.waiting_req_list.append(req)
return
def _init_cache_list(self, current_batch: Batch):
if current_batch is not None:
self.cache_len_list = [
(req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1)
for req in current_batch.reqs
]
else:
self.cache_len_list = []
# @calculate_time(show=True, min_cost_ms=0.1)
def _can_add_new_req(self, req):
self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis
self.cache_len_list.sort(key=lambda x: -x[1])
left_out_len_array = np.array([e[1] for e in self.cache_len_list])
# assert left_out_len_array.min() >= 0
has_run_len_array = np.array([e[0] for e in self.cache_len_list])
cum_run_len_array = np.cumsum(has_run_len_array)
size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
# NOTE: change here < to <=
return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size
def generate_new_batch(self, current_batch: Batch = None):
if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size:
return None
self._init_cache_list(current_batch)
can_run_list = []
new_batch_total_tokens = 0
aborted_count = 0
for req in self.waiting_req_list:
flag = self._can_add_new_req(req)
if req.aborted:
aborted_count += 1
continue
if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens:
can_run_list.append(req)
new_batch_total_tokens += req.input_len
else:
break
if len(can_run_list) != 0:
new_batch = Batch(uuid.uuid4().hex, can_run_list)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
else:
return None
def __len__(self):
return self.waiting_req_list.__len__()

@ -0,0 +1,82 @@
"""Sampling parameters for text generation."""
from typing import List, Optional, Union
_SAMPLING_EPS = 1e-5
class SamplingParams:
def __init__(
self,
do_sample: bool = False,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1, # -1 is for all
ignore_eos: bool = False,
max_new_tokens: int = 16,
stop_sequences: Optional[Union[str, List[str]]] = None # conditions to stop generation
) -> None:
self.do_sample = do_sample
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.ignore_eos = ignore_eos
self.max_new_tokens = max_new_tokens
self.stop_sequences = stop_sequences
if self.do_sample == False:
self.temperature = 1.0
self.top_p = 1.0
self.top_k = 1
if self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS: # temperature is too slow, change to greedy search
self.temperature = 1.0
self.top_k = 1
return
def verify(self):
if self.presence_penalty < 0.0:
raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}")
if self.frequency_penalty < 0.0:
raise ValueError(f"frequency_penalty must >= 0.0, got {self.frequency_penalty}")
if self.temperature <= 0.0:
raise ValueError(f"temperature must > 0.0, got {self.temperature}")
if self.top_p <= 0.0 or self.top_p > 1.0:
raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}")
if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
if self.max_new_tokens < 1:
raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
return
def stop_sentences_to_token_ids(self, tokenizer):
if self.stop_sequences is None:
self.stop_sequences = []
else:
if isinstance(self.stop_sequences, str):
self.stop_sequences = [self.stop_sequences]
new_stop_sequences = []
for stop_str in self.stop_sequences:
stop_str_ids = tokenizer.encode(stop_str)
if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id
stop_str_ids = stop_str_ids[1:]
if len(stop_str_ids) > 0:
new_stop_sequences.append(stop_str_ids)
self.stop_sequences = new_stop_sequences
return
def to_dict(self):
ret = {}
ret["do_sample"] = self.do_sample
ret["presence_penalty"] = self.presence_penalty
ret["frequency_penalty"] = self.frequency_penalty
ret["temperature"] = self.temperature
ret["top_p"] = self.top_p
ret["top_k"] = self.top_k
# if self.ignore_eos is not None:
# ret["ignore_eos"] = self.ignore_eos
# if self.max_tokens is not None:
# ret["max_tokens"] = self.max_tokens
return ret

@ -0,0 +1,43 @@
import time
class Stats:
def __init__(self, log_status, log_stats_interval) -> None:
self.log_stats = log_status
self.log_stats_interval = log_stats_interval
self.last_log_time = time.time()
self.all_tokens = 0
self.output_tokens = 0
self.prompt_tokens = 0
return
def count_prompt_tokens(self, run_batch):
if self.log_stats:
tokens = run_batch.input_tokens()
self.prompt_tokens += tokens
self.all_tokens += tokens
return
def count_output_tokens(self, run_batch):
if self.log_stats:
tokens = len(run_batch.reqs)
self.output_tokens += tokens
self.all_tokens += tokens
return
def print_stats(self):
if not self.log_stats:
return
now = time.time()
if now - self.last_log_time > self.log_stats_interval:
print(
f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
)
self.all_tokens = 0
self.output_tokens = 0
self.prompt_tokens = 0
self.last_log_time = now
return

@ -0,0 +1,243 @@
import time
from typing import List
from .dynamic_batching.infer_batch import InferBatch
from .dynamic_batching.io_struct import Batch, Req
from .dynamic_batching.req_queue import ReqQueue
from .dynamic_batching.sampling_params import SamplingParams
from .dynamic_batching.stats import Stats
from .tensor_parallel import TPInferEngine
class DynamicBatchManager:
def __init__(
self,
tp_engine: TPInferEngine,
max_total_token_num,
batch_max_tokens,
eos_id,
log_stats=True,
log_stats_interval=10,
running_batch: Batch = None,
waiting_req_list: List = [],
):
"""
Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager
max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len)
batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
eos_id : The end token of a seq
log_stats : whether to log stats
log_stats_interval : log stats interval
running_batch : running batch
waiting_req_list : list of waiting requests, initialized before dynamic batch manager
"""
self.engine = tp_engine
self.max_total_token_num = max_total_token_num
running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2
self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list)
# all the inputs should be put into req_queue: waiting req list
self.running_batch: Batch = running_batch
self.eos_id = eos_id
self.has_wait_tokens = 0
self.max_wait_tokens = 10
self.stats_tool = Stats(log_stats, log_stats_interval)
self.mem_usage_interval = log_stats_interval * 2
def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
"""
Add new request to req queue, during initialization all requests are held in waiting list.
"""
req = Req(request_id, prompt_ids, sampling_params)
self.req_queue.append(req)
return
def abort(self, request_id):
if self.running_batch is not None:
for req in self.running_batch.reqs:
if req.request_id == request_id:
req.has_generate_finished = True
req.aborted = True
for req in self.req_queue.waiting_req_list:
if req.request_id == request_id:
req.has_generate_finished = True
req.aborted = True
return
def loop_for_fwd(self):
"""
The main loop for a dynamic batching process.
"""
counter_count = 0
while self.running_batch is not None or self.req_queue.waiting_req_list:
self._step()
counter_count += 1
if self.running_batch is not None:
if counter_count % self.mem_usage_interval == 0:
print(
"current batch size:",
len(self.running_batch.reqs),
"token used ratio:",
self.running_batch.calcu_used_tokens() / self.max_total_token_num,
)
self.stats_tool.print_stats()
if self.running_batch is None:
time.sleep(0.1) # 10ms
def _step(self):
"""
Logic for handling requests
"""
if self.running_batch is None:
new_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens = 0
return
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1
return
else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0
else:
self.stats_tool.count_output_tokens(self.running_batch)
self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1
return
def _init_batch(self, batch: Batch, dtype="fp16"):
reqs = [r.to_rpc_obj() for r in batch.reqs]
batch_id = batch.batch_id
import torch
if dtype == "fp16":
dtype = torch.float16
else:
assert False, "error dtype"
batch_data = InferBatch.init_batch(
batch_id,
reqs,
dtype,
torch.cuda.current_device(),
self.engine.cache_manager,
self.engine.model.config.vocab_size,
self.engine.max_input_len + self.engine.max_output_len,
)
self.engine.cache[batch_id] = batch_data
def _prefill_batch(self, batch):
"""
For all batches, no matter it is a new batch or a mini batch, we need to do prefill first.
"""
self._init_batch(batch)
# TODO: figure out if cache and batch id is needed
ans = self.engine._prefill_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
self._handle_finish_req(batch, has_new_finished_req)
# delete finished reqs
def _decode_batch(self, batch: Batch):
"""
Decoding process
"""
ans = self.engine._decode_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id)
self._handle_finish_req(batch, has_new_finished_req)
def _filter_batch(self, batch: Batch):
batch_id = batch.batch_id
req_id_list = [r.request_id for r in batch.reqs]
batch = self.engine.cache.pop(batch_id)
filter_batch = batch.filter(req_id_list)
del batch
self.engine.cache[batch_id] = filter_batch
def _merge_batch(self, batch1, batch2):
"""
Merge new mini batch into running batch.
"""
batch1 = self.engine.cache.pop(batch1.batch_id)
batch2 = self.engine.cache.pop(batch2.batch_id)
m_batch = InferBatch.merge(batch1, batch2)
self.engine.cache[batch1.batch_id] = m_batch
del batch1
del batch2
def _remove_batch(self, batch):
"""
Remove finished batch.
"""
batch = self.engine.cache.pop(batch.batch_id)
batch.free_self()
del batch
def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
def _filter_runing_batch(self):
if self.running_batch is not None and self.running_batch.is_clear():
self.running_batch = None
def _add_token_id_to_req(self, batch: Batch, req_ans):
for req_id, (new_token_id, new_gen_metadata) in req_ans.items():
req = batch.id_to_reqs[req_id]
req.output_ids.append(new_token_id)
req.output_metadata_list.append(new_gen_metadata)
return
def clean_up(self):
# this logic should be implemented in the future.
pass
def start_dynamic_batching(args, tp_engine, waiting_req_list):
# try:
batch_manager = DynamicBatchManager(
tp_engine=tp_engine,
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)
# except Exception:
# batch_manager.clean_up()
# raise
batch_manager.loop_for_fwd()
return

@ -1,7 +1,6 @@
from typing import Any, Callable, List, Optional, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers import BloomForCausalLM, LlamaForCausalLM
from transformers.generation import GenerationConfig
@ -14,6 +13,8 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy
from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager
# from dynamic_batching.infer_batch import InferBatch
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
_supported_models = [
@ -90,6 +91,8 @@ class TPInferEngine:
self.shard_config = shard_config
self.model = None
self.cache = {}
# optimize the original model by sharding with ShardFormer
self._optimize_model(model=model.to(device))
@ -116,13 +119,15 @@ class TPInferEngine:
def _post_init_gptq_buffer(self, model: nn.Module) -> None:
from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear
HAS_GPTQ_CUDA = False
try:
from colossalai.kernel.op_builder.gptq import GPTQBuilder
gptq_cuda = GPTQBuilder().load()
HAS_GPTQ_CUDA = True
except ImportError:
warnings.warn('CUDA gptq is not installed')
warnings.warn("CUDA gptq is not installed")
HAS_GPTQ_CUDA = False
for name, submodule in model.named_modules():
@ -130,8 +135,9 @@ class TPInferEngine:
self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
if self.use_act_order:
self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures,
submodule.outfeatures)
self.max_inner_outer_dim = max(
self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
)
self.bits = submodule.bits
if not (HAS_GPTQ_CUDA and self.bits == 4):
return
@ -141,15 +147,16 @@ class TPInferEngine:
max_input_len = self.max_input_len
# The temp_state buffer is required to reorder X in the act-order case.
# The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim),
dtype=torch.float16,
device=torch.cuda.current_device())
self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size),
dtype=torch.float16,
device=torch.cuda.current_device())
gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer,
self.gptq_temp_dq_buffer)
self.gptq_temp_state_buffer = torch.zeros(
(max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
)
self.gptq_temp_dq_buffer = torch.zeros(
(1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
)
gptq_cuda.prepare_buffers(
torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
)
# Using the default from exllama repo here.
matmul_recons_thd = 8
matmul_fused_remap = False
@ -270,7 +277,6 @@ class TPInferEngine:
attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
@ -304,6 +310,7 @@ class TPInferEngine:
batch_infer_state.past_key_values_len = 0
batch_infer_state.is_context_stage = True
batch_infer_state.set_cache_manager(self.cache_manager)
return batch_infer_state
@torch.no_grad()
@ -367,6 +374,86 @@ class TPInferEngine:
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device)
infer_state.seq_len += 1
@torch.no_grad()
def forward(self, batch_id, is_prefill):
"""
Forward is used in Dynamic Batching Manager
"""
batch = self.cache.pop(batch_id)
if is_prefill:
input_ = torch.tensor(batch.all_input_ids).cuda()
else:
input_ = batch.input_ids.reshape(len(batch), 1)
batch_args = {
"batch_size": len(batch),
"max_len_in_batch": batch.nopad_max_len_in_batch,
"block_loc": batch.nopad_b_loc,
"start_loc": batch.nopad_b_start_loc,
"seq_len": batch.nopad_b_seq_len,
"cache_manager": batch.cache_manager,
"is_context_stage": is_prefill,
}
infer_state = BatchInferState(**batch_args)
model = self.model
if isinstance(model, LlamaForCausalLM):
model = self.model.model
elif isinstance(model, BloomForCausalLM):
model = self.model.transformer
setattr(model, "infer_state", infer_state)
output = self.model.forward(input_ids=input_)
logits = output.logits
# bsz, seq_len, vocab_size
prob_out = torch.softmax(
logits[
:,
-1,
],
dim=-1,
).squeeze(1)
# prob_out: bsz, vocab_size
predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True)
prob_out = torch.log(prob_out).detach().cpu().numpy()
predict_ids = predict_ids.detach().cpu().numpy()
# [ batch_size, 1 ]
output_dict = {}
new_input_ids = []
for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate(
zip(batch.requests, batch.all_input_ids, predict_ids, prob_out)
):
next_token_id = int(next_token_id)
next_token_logprob = next_token_logprob[next_token_id]
# all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda")
all_input_ids.append(next_token_id)
# all_input_ids_tensor = None
new_input_ids.append(next_token_id)
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] += 1
batch.out_token_id_counts[i][next_token_id] += 1
metadata = {
"id": int(next_token_id),
"logprob": float(next_token_logprob),
}
output_dict[r["request_id"]] = (int(next_token_id), metadata)
batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda()
batch.nopad_total_token_num += len(batch)
batch.nopad_max_len_in_batch += 1
self.cache[batch.batch_id] = batch
return output_dict
@torch.no_grad()
def _prefill_batch(self, batch_id):
return self.forward(batch_id, is_prefill=True)
@torch.no_grad()
def _decode_batch(self, batch_id):
return self.forward(batch_id, is_prefill=False)
# might want to create a sequence pool
# add a single request/sequence/input text at a time and record its length
# In other words, store the actual length of input tokens representing a single input text

@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
base = float(base)
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None))
ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
if ntk_alpha is not None:
ntk_alpha = float(ntk_alpha)

@ -62,12 +62,11 @@ class LlamaInferenceForwards:
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
batch_size = input_ids.shape[0] # input_ids.shape[0]
infer_state = self.infer_state
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
@ -78,15 +77,12 @@ class LlamaInferenceForwards:
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
# NOT READY FOR PRIME TIME
# dummy but work, revise it
past_key_values_length = infer_state.cache_manager.past_key_values_length
# past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
# NOT READY FOR PRIME TIME
# dummy but work, revise it
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1
# NOTE: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
@ -106,23 +102,23 @@ class LlamaInferenceForwards:
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.repeat(batch_size, 1)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
@ -134,6 +130,7 @@ class LlamaInferenceForwards:
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
@ -145,7 +142,7 @@ class LlamaInferenceForwards:
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
(batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
@ -160,7 +157,6 @@ class LlamaInferenceForwards:
next_decoder_cache = () if use_cache else None
infer_state.decode_layer_id = 0
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx] if past_key_values is not None else None
# NOTE: modify here for passing args to decoder layer
@ -184,7 +180,7 @@ class LlamaInferenceForwards:
# update indices
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
if not return_dict:
@ -211,7 +207,6 @@ class LlamaInferenceForwards:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
@ -267,11 +262,8 @@ class LlamaInferenceForwards:
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin)
@ -282,7 +274,6 @@ class LlamaInferenceForwards:
if infer_state.is_context_stage:
# first token generation
# copy key and value calculated in current step to memory manager
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
@ -291,9 +282,7 @@ class LlamaInferenceForwards:
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states)
llama_context_attn_fwd(
query_states,
key_states,
@ -301,7 +290,7 @@ class LlamaInferenceForwards:
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
if infer_state.decode_is_contiguous:
@ -338,7 +327,7 @@ class LlamaInferenceForwards:
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)

@ -2,7 +2,6 @@ try:
import triton
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("Triton is not installed. Please install Triton to use Triton kernels.")

@ -51,7 +51,6 @@ if HAS_TRITON:
assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
num_warps = 2
_fwd_copy_kv_cache_dest[(seq_len,)](
k_ptr,
dest_index_ptr,

@ -27,8 +27,10 @@ if HAS_LLAMA:
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# -----------------------------------
input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
input_ids = torch.Tensor(
[[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]]
).long()
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for casual lm

@ -0,0 +1,94 @@
import pytest
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import DynamicBatchManager
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
TP_SIZE = 1
BATCH_SIZE = 2
MAX_INPUT_LEN = 5
MAX_OUTPUT_LEN = 16
def run():
sampling_params = SamplingParams()
req1 = Req(0, [1], sampling_params)
req2 = Req(1, [2], sampling_params)
req3 = Req(2, [3], sampling_params)
# req 1-3 are initiliazed as token forward requests
req4 = Req(3, [10, 10, 10, 9, 1], sampling_params)
waiting_list = []
waiting_list.append(req1)
waiting_list.append(req2)
waiting_list.append(req3)
# init model and tp engine
llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
model = LlamaForCausalLM(llama_config)
model = model.half()
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
dynamic_batch_manager = DynamicBatchManager(
tp_engine=infer_engine,
max_total_token_num=42,
batch_max_tokens=42,
eos_id=0,
log_stats=False,
log_stats_interval=10,
waiting_req_list=waiting_list,
)
before_add = len(dynamic_batch_manager.req_queue)
# test add req function
dynamic_batch_manager.add_req(req4.prompt_ids, req4.sample_params, req4.request_id)
assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1
# test abort function
dynamic_batch_manager.abort(req4.request_id)
assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True
# test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested
batch = dynamic_batch_manager.req_queue.generate_new_batch()
assert len(batch) == 2
dynamic_batch_manager._init_batch(batch)
assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None
batch.reqs[0].has_generate_finished = True
# filter one finished
batch.filter_finished()
dynamic_batch_manager._filter_batch(batch)
assert len(dynamic_batch_manager.engine.cache) == 1
# test merge batch
new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch)
assert len(new_batch) == 1
dynamic_batch_manager._init_batch(new_batch)
dynamic_batch_manager._merge_batch(batch, new_batch)
assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2
def check_dynamic_batching_manager(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching_manager():
spawn(check_dynamic_batching_manager, 1)
if __name__ == "__main__":
test_dynamic_batching_manager()

@ -0,0 +1,70 @@
import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
import colossalai
from dataclasses import dataclass
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
TP_SIZE = 1
MAX_BATCH_SIZE = 2
MAX_INPUT_LEN = 5
MAX_OUTPUT_LEN = 16
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
@dataclass
class args:
max_total_token_num: int
batch_max_tokens: int
eos_id: int
disable_log_stats: bool
log_stats_interval: int
def run():
arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10)
sampling_params = SamplingParams()
req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
req2 = Req(1, [10, 10, 10, 10, 10], sampling_params)
req3 = Req(2, [0, 0, 10, 10, 10], sampling_params)
req4 = Req(3, [0, 0, 10, 10, 10], sampling_params)
waiting_list = []
waiting_list.append(req1)
waiting_list.append(req2)
waiting_list.append(req3)
waiting_list.append(req4)
llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
model = LlamaForCausalLM(llama_config)
model = model.half()
shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
def check_dynamic_forward(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching():
spawn(check_dynamic_forward, TP_SIZE)
if __name__ == "__main__":
test_dynamic_batching()

@ -38,7 +38,6 @@ def run_llama_test(test_config):
enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
)
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
init_to_get_rotary(model.model, base=10000)
generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
input_tokens = {

Loading…
Cancel
Save