mirror of https://github.com/hpcaitech/ColossalAI
[inference] Dynamic Batching for Single and Multiple GPUs (#4831)
* finish batch manager
* fix dynamic batching
* llama infer
* finish test
* support generating sequences of different lengths
* del prints
* fix bugs

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
parent
8aed02b957
commit
e0757c31fb
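Usage sketch (illustrative, not part of the committed files): at a high level, this commit feeds requests through a ReqQueue into a DynamicBatchManager, which drives a TPInferEngine and tracks per-batch state in InferBatch. A minimal sketch of the intended flow, assuming `infer_engine` is an already initialized TPInferEngine (names and values here are made up):

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import DynamicBatchManager

manager = DynamicBatchManager(
    tp_engine=infer_engine,   # assumed: built from a sharded model, as in the tests below
    max_total_token_num=42,
    batch_max_tokens=42,
    eos_id=0,
    log_stats=False,
)

# Queue a request, then let the manager schedule prefill/decode steps until everything finishes.
manager.add_req(prompt_ids=[1, 2, 3], sampling_params=SamplingParams(max_new_tokens=8), request_id=0)
manager.loop_for_fwd()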
@@ -0,0 +1,346 @@
import collections
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch

from colossalai.inference.tensor_parallel import MemoryManager

# make batch infer state an attr of InferBatch


class InferSamplingParams:
    def __init__(
        self,
        do_sample: bool = False,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        vocab_size: int = -1,
    ) -> None:
        self.do_sample = do_sample
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        if self.top_k == -1:
            self.top_k = vocab_size
        return


@dataclass
class InferBatch:
    batch_id: int
    requests: List
    requests_idx_mapping: Dict[int, int]

    input_ids: torch.Tensor

    all_input_ids: List[List[int]]
    input_lengths: List[int]

    out_token_id_counts: List
    sampling_param_list: List[InferSamplingParams]

    nopad_total_token_num: int
    nopad_max_len_in_batch: int
    nopad_b_loc: torch.Tensor
    nopad_b_start_loc: torch.Tensor
    nopad_b_seq_len: torch.Tensor
    cache_manager: MemoryManager
    max_total_len: int

    @classmethod
    @torch.no_grad()
    def init_batch(
        cls,
        batch_id,
        requests,
        dtype: torch.dtype,
        device: torch.device,
        cache_manager: MemoryManager,
        vocab_size: int,
        max_total_len: int,
    ) -> "InferBatch":
        input_lengths = []
        all_input_ids = []
        requests_idx_mapping = {}

        out_token_id_counts = []
        sampling_param_list = []

        nopad_total_token_num = 0
        nopad_max_len_in_batch = 0
        nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda")
        # to avoid memory leaks, we pre-allocate 12 extra slots for each batch
        nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda")
        for i, r in enumerate(requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r["request_id"]] = i

            tokenized_input = r["input_id"]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)
            all_input_ids.append(tokenized_input)
            out_token_id_counts.append(collections.defaultdict(int))

            # postprocessor
            sampling_param = r["sampling_param"]
            sampling_param["vocab_size"] = vocab_size
            sampling_param_list.append(InferSamplingParams(**sampling_param))

            nopad_total_token_num += input_length
            nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length)

        nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda")
        nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]

        if len(requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
        else:
            input_ids = all_input_ids[0]

        # Create tensors on device
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

        return cls(
            batch_id=batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            nopad_total_token_num=nopad_total_token_num,
            nopad_max_len_in_batch=nopad_max_len_in_batch,
            nopad_b_loc=nopad_b_loc,
            nopad_b_start_loc=nopad_b_start_loc,
            nopad_b_seq_len=nopad_b_seq_len,
            out_token_id_counts=out_token_id_counts,
            sampling_param_list=sampling_param_list,
            cache_manager=cache_manager,
            max_total_len=max_total_len,
        )

    @torch.no_grad()
    def free_self(self) -> None:
        """
        Free the cache slots held by the InferBatch itself.
        """
        remove_index = []
        for idx in range(len(self)):
            remove_index.append(
                self.nopad_b_loc[
                    idx,
                    (self.nopad_max_len_in_batch - 1)
                    - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
                ]
            )
        remove_index = torch.cat(remove_index, dim=-1)
        self.cache_manager.free(remove_index)

    @torch.no_grad()
    def filter(self, request_ids: List[int]) -> "InferBatch":
        """
        Filter out finished requests and return a new InferBatch holding the remaining ones.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self
        requests_idx_mapping = {}
        indices = []
        requests = []
        all_input_ids = []
        input_lengths = []
        nopad_total_token_num = 0
        nopad_max_len_in_batch = 0
        nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda")
        nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
        nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")

        left_idx = []
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            left_idx.append(idx)

        left_idx_set = set(left_idx)
        remove_index = []
        for idx in range(len(self)):
            if idx not in left_idx_set:
                remove_index.append(
                    self.nopad_b_loc[
                        idx,
                        (self.nopad_max_len_in_batch - 1)
                        - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
                    ]
                )
        remove_index = torch.cat(remove_index, dim=-1)
        self.cache_manager.free(remove_index)

        nopad_max_len_in_batch = 0
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)

        nopad_b_seq_len[:] = self.nopad_b_seq_len[indices]
        nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item()
        nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
        nopad_total_token_num = torch.sum(nopad_b_seq_len).item()

        nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[
            indices,
            (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1),
        ]
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            requests.append(self.requests[idx])
            all_input_ids.append(self.all_input_ids[idx])
            input_lengths.append(self.input_lengths[idx])

        input_ids = self.input_ids[indices]

        return InferBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            nopad_total_token_num=nopad_total_token_num,
            nopad_max_len_in_batch=nopad_max_len_in_batch,
            nopad_b_loc=nopad_b_loc,
            nopad_b_start_loc=nopad_b_start_loc,
            nopad_b_seq_len=nopad_b_seq_len,
            out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices],
            sampling_param_list=[self.sampling_param_list[_i] for _i in indices],
            cache_manager=self.cache_manager,
            max_total_len=self.max_total_len,
        )

    @classmethod
    @torch.no_grad()
    def merge(cls, batch1, batch2) -> "InferBatch":
        """
        Return a new InferBatch that merges batch1 and batch2.
        """
        requests = batch1.requests + batch2.requests
        requests_idx_mapping = {}
        new_batch_size = len(batch1) + len(batch2)

        input_ids = batch1.input_ids.new_empty(new_batch_size)
        all_input_ids = []
        input_lengths = []
        out_token_id_counts = []
        sampling_param_list = []

        cumulative_batch_size = 0
        nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num
        nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch)
        max_total_len = max(batch1.max_total_len, batch2.max_total_len)
        nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda")
        nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
        nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
        nopad_start_loc_len_temp = 0
        batches = [batch1, batch2]
        for i, batch in enumerate(batches):
            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size
            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            input_ids[start_index:end_index] = batch.input_ids
            nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len
            nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp
            nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1]
            nopad_b_loc[
                start_index:end_index,
                nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1,
            ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1]

            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            out_token_id_counts.extend(batch.out_token_id_counts)
            sampling_param_list.extend(batch.sampling_param_list)
            # Update
            cumulative_batch_size += len(batch)

        nopad_b_loc[:, nopad_max_len_in_batch - 1] = (
            nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda")
        )
        return InferBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            nopad_total_token_num=nopad_total_token_num,
            nopad_max_len_in_batch=nopad_max_len_in_batch,
            nopad_b_loc=nopad_b_loc,
            nopad_b_start_loc=nopad_b_start_loc,
            nopad_b_seq_len=nopad_b_seq_len,
            out_token_id_counts=out_token_id_counts,
            sampling_param_list=sampling_param_list,
            cache_manager=batches[0].cache_manager,
            max_total_len=max_total_len,
        )

    def __len__(self):
        return len(self.requests)

    def get_post_sample_tensors(
        self,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        int,
    ]:
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        top_ks: List[int] = []
        p_token_ids: List[int] = []
        p_token_counts: List[int] = []
        p_seq_len: List[int] = [
            0,
        ]
        p_max_len_in_batch: int = 0
        for i, id_to_count in enumerate(self.out_token_id_counts):
            sample_param = self.sampling_param_list[i]
            presence_penalties.append(sample_param.presence_penalty)
            frequency_penalties.append(sample_param.frequency_penalty)
            temperatures.append(sample_param.temperature)
            top_ps.append(sample_param.top_p)
            top_ks.append(sample_param.top_k)

            for token_id, count in id_to_count.items():
                p_token_ids.append(token_id)
                p_token_counts.append(count)
            p_seq_len.append(len(id_to_count))
            p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count))

        presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda")
        frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda")
        temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda")
        top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda")
        top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda")
        p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda")
        p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda")
        p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
        p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
        return (
            presence_penalties,
            frequency_penalties,
            temperatures,
            top_ps,
            top_ks,
            p_token_ids,
            p_token_counts,
            p_cumsum_seq_len,
            p_max_len_in_batch,
        )
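Usage sketch (illustrative, not part of the committed files): `InferBatch.init_batch` in `colossalai.inference.dynamic_batching.infer_batch` consumes plain request dicts rather than Req objects; the expected keys can be read off the loop body above (they match `Req.to_rpc_obj` in the next file). A minimal sketch, assuming a CUDA device and an already built MemoryManager named `cache_manager` (variable names here are made up):

import torch

from colossalai.inference.dynamic_batching.infer_batch import InferBatch

# Each request dict mirrors Req.to_rpc_obj(): token ids plus a sampling-param dict.
requests = [
    {"request_id": 0, "input_id": [1, 2, 3], "output_len": 16, "sampling_param": {"do_sample": False}},
    {"request_id": 1, "input_id": [4, 5], "output_len": 16, "sampling_param": {"do_sample": False}},
]

batch = InferBatch.init_batch(
    batch_id="demo",              # any hashable id works here
    requests=requests,
    dtype=torch.float16,
    device=torch.device("cuda"),
    cache_manager=cache_manager,  # assumed: a MemoryManager built by the TP engine
    vocab_size=32000,             # assumed vocab size
    max_total_len=32,             # max input + output length per sequence
)
print(len(batch), batch.nopad_total_token_num)  # 2, 5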
@@ -0,0 +1,149 @@
from typing import Dict, List, Tuple

from .sampling_params import SamplingParams


class Req:
    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
        self.request_id = request_id
        self.prompt_ids = prompt_ids
        self.input_len = len(prompt_ids)
        self.max_output_len = sample_params.max_new_tokens
        self.sample_params = sample_params
        self.output_ids = []
        self.output_metadata_list = []
        self.has_generate_finished = False
        self.aborted = False

    def to_rpc_obj(self):
        return {
            "request_id": self.request_id,
            "input_id": self.prompt_ids,
            "output_len": self.max_output_len,
            "sampling_param": self.sample_params.to_dict(),
        }

    def to_req_detokenization_state(self):
        out = ReqDetokenizationState(
            self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos
        )
        if self.output_metadata_list:
            out.gen_metadata.update(self.output_metadata_list[-1])
        return out

    def stop_sequences_matched(self):
        # should we add stop sequences to the sample params?
        if self.sample_params.stop_sequences is not None:
            for stop_token_ids in self.sample_params.stop_sequences:
                stop_len = len(stop_token_ids)
                if (
                    stop_len > 0
                    and len(self.output_ids) >= stop_len
                    and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
                ):
                    return True
        return False

    def __repr__(self):
        return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, "


class ReqDetokenizationState:
    def __init__(
        self,
        request_id: str,
        prompt_ids: List[int],
        max_output_len: int,
        ignore_eos: bool,
    ) -> None:
        self.request_id = request_id
        self.prompt_ids = prompt_ids
        self.output_ids = []
        self.output_tokens = []
        self.output_str = ""
        self.sub_texts = []
        self.current_sub_text = []
        self.max_output_len = max_output_len
        self.ignore_eos = ignore_eos
        self.gen_metadata = {}


class Batch:
    def __init__(self, batch_id, reqs: List[Req]):
        self.batch_id = batch_id
        self.reqs = reqs
        self.id_to_reqs = {req.request_id: req for req in reqs}

    def input_tokens(self):
        batch_input_tokens = 0
        for req in self.reqs:
            batch_input_tokens += req.input_len
        return batch_input_tokens

    def calcu_max_tokens(self):
        tokens = 0
        for req in self.reqs:
            tokens += req.input_len + req.max_output_len
        return tokens

    def calcu_used_tokens(self):
        tokens = 0
        for req in self.reqs:
            tokens += req.input_len + len(req.output_ids)
        return tokens

    def mark_finished_req(self, eos_id):
        has_new_finish = False
        for req in self.reqs:
            if req.stop_sequences_matched():
                req.has_generate_finished = True
                has_new_finish = True
            if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos is False:
                req.has_generate_finished = True
                has_new_finish = True
            if len(req.output_ids) >= req.max_output_len or req.aborted:
                req.has_generate_finished = True
                has_new_finish = True
        return has_new_finish

    def filter_finished(self):
        """
        Filter finished requests from the batch; the finished ones are removed from 'reqs'.
        """
        # TODO: the logic of return should be defined here.
        unfinished_req = []
        for req in self.reqs:
            if not req.has_generate_finished:
                unfinished_req.append(req)
        self.reqs = unfinished_req
        self.id_to_reqs = {req.request_id: req for req in self.reqs}

    def is_clear(self):
        return len(self.reqs) == 0

    def merge(self, mini_batch):
        for _req in mini_batch.reqs:
            self.reqs.append(_req)
        self.id_to_reqs = {req.request_id: req for req in self.reqs}
        return

    def __repr__(self):
        return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, "

    def __len__(self):
        return len(self.reqs)


class BatchTokenIdOut:
    def __init__(self):
        self.reqs_infs: List[
            Tuple[str, int, Dict, bool, bool]
        ] = []  # [req_id, new_token_id, gen_metadata, finished_state, abort_state]


class BatchStrOut:
    def __init__(self):
        self.reqs_infs: List[
            Tuple[str, str, Dict, bool, bool]
        ] = []  # [req_id, token_str, gen_metadata, finished_state, abort_state]


class AbortReq:
    def __init__(self, req_id):
        self.req_id = req_id
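Usage sketch (illustrative, not part of the committed files): how Req and Batch from `io_struct` interact with the stop conditions above. The token ids are made up, and `stop_sequences` is assigned directly as token-id lists to mimic the state after `SamplingParams.stop_sentences_to_token_ids`:

from colossalai.inference.dynamic_batching.io_struct import Batch, Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

params = SamplingParams(max_new_tokens=4)
params.stop_sequences = [[7, 8]]          # already tokenized stop sequence (illustrative ids)

req = Req(request_id=0, prompt_ids=[1, 2, 3], sample_params=params)
batch = Batch(batch_id="demo", reqs=[req])

req.output_ids.extend([5, 7, 8])          # generated suffix matches the stop sequence
print(req.stop_sequences_matched())       # True
print(batch.mark_finished_req(eos_id=2))  # True: the request is flagged as finished
batch.filter_finished()
print(batch.is_clear())                   # True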
@@ -0,0 +1,71 @@
import uuid
from typing import List

import numpy as np

from .io_struct import Batch, Req


class ReqQueue:
    def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None:
        self.max_total_tokens = max_total_tokens
        assert batch_max_tokens is not None
        self.batch_max_tokens = batch_max_tokens
        self.running_max_req_size = running_max_req_size
        self.waiting_req_list: List[Req] = waiting_req_list

    def append(self, req):
        self.waiting_req_list.append(req)
        return

    def _init_cache_list(self, current_batch: Batch):
        if current_batch is not None:
            self.cache_len_list = [
                (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1)
                for req in current_batch.reqs
            ]
        else:
            self.cache_len_list = []

    # @calculate_time(show=True, min_cost_ms=0.1)
    def _can_add_new_req(self, req):
        self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1))  # hard to analyze precisely
        self.cache_len_list.sort(key=lambda x: -x[1])

        left_out_len_array = np.array([e[1] for e in self.cache_len_list])
        # assert left_out_len_array.min() >= 0
        has_run_len_array = np.array([e[0] for e in self.cache_len_list])
        cum_run_len_array = np.cumsum(has_run_len_array)
        size_array = np.arange(1, len(self.cache_len_list) + 1, 1)

        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
        # NOTE: changed < to <= here
        return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size

    def generate_new_batch(self, current_batch: Batch = None):
        if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size:
            return None
        self._init_cache_list(current_batch)
        can_run_list = []
        new_batch_total_tokens = 0
        aborted_count = 0
        for req in self.waiting_req_list:
            flag = self._can_add_new_req(req)
            if req.aborted:
                aborted_count += 1
                continue
            if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens:
                can_run_list.append(req)
                new_batch_total_tokens += req.input_len
            else:
                break

        if len(can_run_list) != 0:
            new_batch = Batch(uuid.uuid4().hex, can_run_list)
            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
            return new_batch
        else:
            return None

    def __len__(self):
        return self.waiting_req_list.__len__()
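Note (illustrative, not part of the committed files): the admission check in `ReqQueue._can_add_new_req` estimates the worst-case KV-cache demand. Each entry of `cache_len_list` is `(tokens already cached, decode steps still allowed)`; after sorting by remaining steps in descending order, the peak demand if the first k requests all run to completion is `remaining[k] * k + cumsum(cached)[k]`, and the maximum over k must fit into `max_total_tokens`. A small numeric sketch of that formula, with made-up values:

import numpy as np

# (already cached tokens, remaining decode budget) for three hypothetical requests,
# sorted by remaining budget, largest first, as _can_add_new_req does.
cache_len_list = [(6, 10), (4, 8), (9, 3)]

left_out_len = np.array([e[1] for e in cache_len_list])  # [10, 8, 3]
has_run_len = np.array([e[0] for e in cache_len_list])   # [6, 4, 9]
cum_run_len = np.cumsum(has_run_len)                     # [6, 10, 19]
size = np.arange(1, len(cache_len_list) + 1)             # [1, 2, 3]

need_max_token_num = (left_out_len * size + cum_run_len).max()
print(need_max_token_num)  # max(16, 26, 28) = 28 -> admit only if 28 <= max_total_tokens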
@@ -0,0 +1,82 @@
"""Sampling parameters for text generation."""
from typing import List, Optional, Union

_SAMPLING_EPS = 1e-5


class SamplingParams:
    def __init__(
        self,
        do_sample: bool = False,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,  # -1 means consider all tokens
        ignore_eos: bool = False,
        max_new_tokens: int = 16,
        stop_sequences: Optional[Union[str, List[str]]] = None,  # conditions to stop generation
    ) -> None:
        self.do_sample = do_sample
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.ignore_eos = ignore_eos
        self.max_new_tokens = max_new_tokens
        self.stop_sequences = stop_sequences
        if not self.do_sample:
            self.temperature = 1.0
            self.top_p = 1.0
            self.top_k = 1
        if self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS:
            # temperature is too low; fall back to greedy search
            self.temperature = 1.0
            self.top_k = 1
        return

    def verify(self):
        if self.presence_penalty < 0.0:
            raise ValueError(f"presence_penalty must be >= 0.0, got {self.presence_penalty}")
        if self.frequency_penalty < 0.0:
            raise ValueError(f"frequency_penalty must be >= 0.0, got {self.frequency_penalty}")
        if self.temperature <= 0.0:
            raise ValueError(f"temperature must be > 0.0, got {self.temperature}")
        if self.top_p <= 0.0 or self.top_p > 1.0:
            raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable) or at least 1, got {self.top_k}.")
        if self.max_new_tokens < 1:
            raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
        return

    def stop_sentences_to_token_ids(self, tokenizer):
        if self.stop_sequences is None:
            self.stop_sequences = []
        else:
            if isinstance(self.stop_sequences, str):
                self.stop_sequences = [self.stop_sequences]
            new_stop_sequences = []
            for stop_str in self.stop_sequences:
                stop_str_ids = tokenizer.encode(stop_str)
                if stop_str_ids is not None and len(stop_str_ids) >= 1:  # remove bos_token_id
                    stop_str_ids = stop_str_ids[1:]
                if len(stop_str_ids) > 0:
                    new_stop_sequences.append(stop_str_ids)
            self.stop_sequences = new_stop_sequences
        return

    def to_dict(self):
        ret = {}
        ret["do_sample"] = self.do_sample
        ret["presence_penalty"] = self.presence_penalty
        ret["frequency_penalty"] = self.frequency_penalty
        ret["temperature"] = self.temperature
        ret["top_p"] = self.top_p
        ret["top_k"] = self.top_k
        # if self.ignore_eos is not None:
        #     ret["ignore_eos"] = self.ignore_eos
        # if self.max_tokens is not None:
        #     ret["max_tokens"] = self.max_tokens
        return ret
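Usage sketch (illustrative, not part of the committed files): the greedy fall-backs and stop-sequence tokenization in SamplingParams. The tokenizer line assumes a HuggingFace-style tokenizer object and is left commented out as an assumption:

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

params = SamplingParams(do_sample=False, temperature=0.7, max_new_tokens=8)
print(params.temperature, params.top_k)  # 1.0 1 -> greedy decoding since do_sample is False

params.verify()  # raises ValueError on out-of-range settings, passes here

# Convert stop strings into token-id sequences (drops the leading BOS token).
params.stop_sequences = ["###"]
# params.stop_sentences_to_token_ids(tokenizer)  # tokenizer: assumed HF-style, e.g. a Llama tokenizer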
@@ -0,0 +1,43 @@
import time


class Stats:
    def __init__(self, log_status, log_stats_interval) -> None:
        self.log_stats = log_status
        self.log_stats_interval = log_stats_interval
        self.last_log_time = time.time()
        self.all_tokens = 0
        self.output_tokens = 0
        self.prompt_tokens = 0
        return

    def count_prompt_tokens(self, run_batch):
        if self.log_stats:
            tokens = run_batch.input_tokens()
            self.prompt_tokens += tokens
            self.all_tokens += tokens
        return

    def count_output_tokens(self, run_batch):
        if self.log_stats:
            tokens = len(run_batch.reqs)
            self.output_tokens += tokens
            self.all_tokens += tokens
        return

    def print_stats(self):
        if not self.log_stats:
            return

        now = time.time()
        if now - self.last_log_time > self.log_stats_interval:
            print(
                f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
                f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
                f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
            )
            self.all_tokens = 0
            self.output_tokens = 0
            self.prompt_tokens = 0
            self.last_log_time = now
        return
@@ -0,0 +1,243 @@
import time
from typing import List

import torch

from .dynamic_batching.infer_batch import InferBatch
from .dynamic_batching.io_struct import Batch, Req
from .dynamic_batching.req_queue import ReqQueue
from .dynamic_batching.sampling_params import SamplingParams
from .dynamic_batching.stats import Stats
from .tensor_parallel import TPInferEngine


class DynamicBatchManager:
    def __init__(
        self,
        tp_engine: TPInferEngine,
        max_total_token_num,
        batch_max_tokens,
        eos_id,
        log_stats=True,
        log_stats_interval=10,
        running_batch: Batch = None,
        waiting_req_list: List = [],
    ):
        """
        Args:
            tp_engine : the TP engine that the dynamic batch manager holds, built before the manager
            max_total_token_num : max_total_token_num for the memory manager, defaults to max batch size * (max input len + max output len)
            batch_max_tokens : max tokens of one batch, defaults to (max input len + max output len) * num_requests
            running_max_req_size : max request size of the running batch, equals MAX_BATCH_SIZE of the TP engine
            eos_id : the end-of-sequence token id
            log_stats : whether to log stats
            log_stats_interval : log stats interval (in steps)
            running_batch : the currently running batch
            waiting_req_list : list of waiting requests, initialized before the dynamic batch manager
        """
        self.engine = tp_engine
        self.max_total_token_num = max_total_token_num
        running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2
        self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list)
        # all the inputs should be put into req_queue: waiting req list

        self.running_batch: Batch = running_batch
        self.eos_id = eos_id
        self.has_wait_tokens = 0
        self.max_wait_tokens = 10

        self.stats_tool = Stats(log_stats, log_stats_interval)
        self.mem_usage_interval = log_stats_interval * 2

    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
        """
        Add a new request to the request queue; during initialization all requests are held in the waiting list.
        """
        req = Req(request_id, prompt_ids, sampling_params)
        self.req_queue.append(req)
        return

    def abort(self, request_id):
        if self.running_batch is not None:
            for req in self.running_batch.reqs:
                if req.request_id == request_id:
                    req.has_generate_finished = True
                    req.aborted = True
        for req in self.req_queue.waiting_req_list:
            if req.request_id == request_id:
                req.has_generate_finished = True
                req.aborted = True
        return

    def loop_for_fwd(self):
        """
        The main loop for the dynamic batching process.
        """
        counter_count = 0
        while self.running_batch is not None or self.req_queue.waiting_req_list:
            self._step()
            counter_count += 1
            if self.running_batch is not None:
                if counter_count % self.mem_usage_interval == 0:
                    print(
                        "current batch size:",
                        len(self.running_batch.reqs),
                        "token used ratio:",
                        self.running_batch.calcu_used_tokens() / self.max_total_token_num,
                    )
                self.stats_tool.print_stats()

            if self.running_batch is None:
                time.sleep(0.1)  # sleep 100 ms when there is nothing to run

    def _step(self):
        """
        Logic for handling requests
        """

        if self.running_batch is None:
            new_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_batch is not None:
                self.stats_tool.count_prompt_tokens(new_batch)
                self.running_batch = new_batch
                self._prefill_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens = 0
            return

        if self.has_wait_tokens < self.max_wait_tokens:
            self.stats_tool.count_output_tokens(self.running_batch)
            self._decode_batch(self.running_batch)
            self._filter_runing_batch()
            self.has_wait_tokens += 1
            return
        else:
            new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_mini_batch is not None:
                self.stats_tool.count_prompt_tokens(new_mini_batch)
                self._prefill_batch(new_mini_batch)
                if not new_mini_batch.is_clear():
                    self._merge_batch(self.running_batch, new_mini_batch)
                    self.running_batch.merge(new_mini_batch)
                self.has_wait_tokens = 0
            else:
                self.stats_tool.count_output_tokens(self.running_batch)
                self._decode_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens += 1

        return

    def _init_batch(self, batch: Batch, dtype="fp16"):
        reqs = [r.to_rpc_obj() for r in batch.reqs]
        batch_id = batch.batch_id

        if dtype == "fp16":
            dtype = torch.float16
        else:
            assert False, "unsupported dtype"

        batch_data = InferBatch.init_batch(
            batch_id,
            reqs,
            dtype,
            torch.cuda.current_device(),
            self.engine.cache_manager,
            self.engine.model.config.vocab_size,
            self.engine.max_input_len + self.engine.max_output_len,
        )
        self.engine.cache[batch_id] = batch_data

    def _prefill_batch(self, batch):
        """
        Every batch, whether it is a freshly scheduled batch or a mini batch, goes through prefill first.
        """
        self._init_batch(batch)

        # TODO: figure out if cache and batch id is needed
        ans = self.engine._prefill_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id)
        self._handle_finish_req(batch, has_new_finished_req)
        # delete finished reqs

    def _decode_batch(self, batch: Batch):
        """
        Decoding process
        """
        ans = self.engine._decode_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id)
        self._handle_finish_req(batch, has_new_finished_req)

    def _filter_batch(self, batch: Batch):
        batch_id = batch.batch_id
        req_id_list = [r.request_id for r in batch.reqs]
        batch = self.engine.cache.pop(batch_id)
        filter_batch = batch.filter(req_id_list)
        del batch
        self.engine.cache[batch_id] = filter_batch

    def _merge_batch(self, batch1, batch2):
        """
        Merge the new mini batch into the running batch.
        """
        batch1 = self.engine.cache.pop(batch1.batch_id)
        batch2 = self.engine.cache.pop(batch2.batch_id)

        m_batch = InferBatch.merge(batch1, batch2)
        self.engine.cache[batch1.batch_id] = m_batch
        del batch1
        del batch2

    def _remove_batch(self, batch):
        """
        Remove a finished batch.
        """
        batch = self.engine.cache.pop(batch.batch_id)
        batch.free_self()
        del batch

    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
        if has_new_finished_req:
            batch.filter_finished()
            if batch.is_clear():
                self._remove_batch(batch)
            else:
                self._filter_batch(batch)

    def _filter_runing_batch(self):
        if self.running_batch is not None and self.running_batch.is_clear():
            self.running_batch = None

    def _add_token_id_to_req(self, batch: Batch, req_ans):
        for req_id, (new_token_id, new_gen_metadata) in req_ans.items():
            req = batch.id_to_reqs[req_id]
            req.output_ids.append(new_token_id)
            req.output_metadata_list.append(new_gen_metadata)
        return

    def clean_up(self):
        # this logic should be implemented in the future.
        pass


def start_dynamic_batching(args, tp_engine, waiting_req_list):
    # try:
    batch_manager = DynamicBatchManager(
        tp_engine=tp_engine,
        max_total_token_num=args.max_total_token_num,
        batch_max_tokens=args.batch_max_tokens,
        eos_id=args.eos_id,
        log_stats=not args.disable_log_stats,
        log_stats_interval=args.log_stats_interval,
        waiting_req_list=waiting_req_list,
    )

    # except Exception:
    #     batch_manager.clean_up()
    #     raise

    batch_manager.loop_for_fwd()
    return
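Usage sketch (illustrative, not part of the committed files): `start_dynamic_batching` only reads a handful of attributes from `args`, so any namespace-like object works. A minimal sketch using `argparse.Namespace` with made-up values; `tp_engine` and `waiting_req_list` are assumed to be built as in the tests below:

from argparse import Namespace

from colossalai.inference.manager import start_dynamic_batching

args = Namespace(
    max_total_token_num=42,   # KV-cache budget for the memory manager
    batch_max_tokens=42,      # max prompt tokens admitted into one batch
    eos_id=0,                 # end-of-sequence token id
    disable_log_stats=False,  # log_stats = not disable_log_stats
    log_stats_interval=10,    # steps between throughput logs
)

# tp_engine: an initialized TPInferEngine; waiting_req_list: a list of Req objects.
start_dynamic_batching(args, tp_engine=tp_engine, waiting_req_list=waiting_req_list)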
@@ -0,0 +1,94 @@
import pytest
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import DynamicBatchManager
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

TP_SIZE = 1
BATCH_SIZE = 2
MAX_INPUT_LEN = 5
MAX_OUTPUT_LEN = 16


def run():
    sampling_params = SamplingParams()

    req1 = Req(0, [1], sampling_params)
    req2 = Req(1, [2], sampling_params)
    req3 = Req(2, [3], sampling_params)
    # req 1-3 are initialized as token forward requests
    req4 = Req(3, [10, 10, 10, 9, 1], sampling_params)
    waiting_list = []
    waiting_list.append(req1)
    waiting_list.append(req2)
    waiting_list.append(req3)

    # init model and tp engine
    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    dynamic_batch_manager = DynamicBatchManager(
        tp_engine=infer_engine,
        max_total_token_num=42,
        batch_max_tokens=42,
        eos_id=0,
        log_stats=False,
        log_stats_interval=10,
        waiting_req_list=waiting_list,
    )
    before_add = len(dynamic_batch_manager.req_queue)

    # test add req function
    dynamic_batch_manager.add_req(req4.prompt_ids, req4.sample_params, req4.request_id)
    assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1

    # test abort function
    dynamic_batch_manager.abort(req4.request_id)
    assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True

    # test filter batch function; loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested
    batch = dynamic_batch_manager.req_queue.generate_new_batch()
    assert len(batch) == 2

    dynamic_batch_manager._init_batch(batch)
    assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None

    batch.reqs[0].has_generate_finished = True
    # filter one finished
    batch.filter_finished()
    dynamic_batch_manager._filter_batch(batch)
    assert len(dynamic_batch_manager.engine.cache) == 1

    # test merge batch
    new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch)
    assert len(new_batch) == 1
    dynamic_batch_manager._init_batch(new_batch)
    dynamic_batch_manager._merge_batch(batch, new_batch)

    assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2


def check_dynamic_batching_manager(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching_manager():
    spawn(check_dynamic_batching_manager, 1)


if __name__ == "__main__":
    test_dynamic_batching_manager()
@@ -0,0 +1,70 @@
from dataclasses import dataclass

import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

TP_SIZE = 1
MAX_BATCH_SIZE = 2
MAX_INPUT_LEN = 5
MAX_OUTPUT_LEN = 16
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@dataclass
class args:
    max_total_token_num: int
    batch_max_tokens: int
    eos_id: int
    disable_log_stats: bool
    log_stats_interval: int


def run():
    arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10)
    sampling_params = SamplingParams()

    req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
    req2 = Req(1, [10, 10, 10, 10, 10], sampling_params)
    req3 = Req(2, [0, 0, 10, 10, 10], sampling_params)
    req4 = Req(3, [0, 0, 10, 10, 10], sampling_params)

    waiting_list = []
    waiting_list.append(req1)
    waiting_list.append(req2)
    waiting_list.append(req3)
    waiting_list.append(req4)

    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)

    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)


def check_dynamic_forward(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching():
    spawn(check_dynamic_forward, TP_SIZE)


if __name__ == "__main__":
    test_dynamic_batching()