mirror of https://github.com/hpcaitech/ColossalAI
[Inference] Dynamic Batching Inference, online and offline (#4953)
* [inference] Dynamic Batching for Single and Multiple GPUs (#4831) * finish batch manager * 1 * first * fix * fix dynamic batching * llama infer * finish test * support different lengths generating * del prints * del prints * fix * fix bug --------- Co-authored-by: CjhHa1 <cjh18671720497@outlook.com> * [inference] Async dynamic batching (#4894) * finish input and output logic * add generate * test forward * 1 * [inference]Re push async dynamic batching (#4901) * adapt to ray server * finish async * finish test * del test --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * Revert "[inference]Re push async dynamic batching (#4901)" (#4905) This reverts commit fbf3c09e67. * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced140250. * Revert "[inference] Async dynamic batching (#4894)" (#4909) This reverts commit fced140250. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * [infer]Add Ray Distributed Environment Init Scripts (#4911) * Revert "[inference] Async dynamic batching (#4894)" This reverts commit fced140250. * Add Ray Distributed Environment Init Scripts * support DynamicBatchManager base function * revert _set_tokenizer version * add driver async generate * add async test * fix bugs in test_ray_dist.py * add get_tokenizer.py * fix code style * fix bugs about No module named 'pydantic' in ci test * fix bugs in ci test * fix bugs in ci test * fix bugs in ci test * support dynamic batch for bloom model and is_running function * [Inference]Test for new Async engine (#4935) * infer engine * infer engine * test engine * test engine * new manager * change step * add * test * fix * fix * finish test * finish test * finish test * finish test * add license --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> * add assertion for config (#4947) * [Inference] Finish dynamic batching offline test (#4948) * test * fix test * fix quant * add default * fix * fix some bugs * fix some bugs * fix * fix bug * fix bugs * reset param --------- Co-authored-by: yuehuayingxueluo <867460659@qq.com> Co-authored-by: Cuiqing Li <lixx3527@gmail.com> Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
parent
4e4a10c97d
commit
cf579ff46d
@@ -0,0 +1,133 @@
import asyncio

from colossalai.inference.dynamic_batching.ray_dist_init import Driver

from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams


class RequestTracker:
    """
    A class that tracks all in-flight requests; an abstraction for async request management.
    """

    def __init__(self) -> None:
        self._requests: asyncio.Queue[str] = asyncio.Queue()
        self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
        self.new_requests_event = None

    def __contains__(self, item):
        # asyncio.Queue does not support membership tests directly;
        # check the underlying deque instead.
        return item in self._requests._queue

    def init_event(self):
        self.new_requests_event = asyncio.Event()

    def add_request(self, request_id: str):
        """Add a request to be sent to the engine on the next background
        loop iteration."""
        self._requests.put_nowait(request_id)
        self.new_requests_event.set()  # NOTE: we may find a better way to clear this event

    def add_stop(self):
        """
        Add a StopIteration flag to stop the async generator.
        """
        self._finished_requests.put_nowait(StopIteration)
        self.new_requests_event.clear()

    def process_request_output(self, request_output: RequestOutput) -> None:
        """Process a request output from the engine."""
        self._finished_requests.put_nowait(request_output)

    async def wait_for_new_requests(self):
        await self.new_requests_event.wait()

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        result = await self._finished_requests.get()
        if result is StopIteration:
            raise StopAsyncIteration
        return result


class Async_Engine:
    """
    Uses an engine to launch the Ray Driver --> Ray Worker --> Async_Manager pipeline.
    Background loop: runs inference over the requests in the waiting list (listener).
    Request tracker: manages incoming requests and stores finished ones.
    generate: the exposed function for adding new inputs and returning finished outputs.
    """

    def __init__(
        self,
        router_config,
        engine_config,
        start_engine_loop: bool = True,
    ) -> None:
        self.driver = Driver(router_config=router_config, engine_config=engine_config)
        self.background_loop = None
        self.start_engine_loop = start_engine_loop
        self._request_tracker = RequestTracker()

    def _step(self):
        """
        Logic for handling requests.
        """
        request_outputs = self.driver.step()
        if request_outputs is not None:
            for request_output in request_outputs:
                self._request_tracker.process_request_output(request_output)
            self._request_tracker.add_stop()

    def abort_request(self, request_id: str):
        self.driver.abort(request_id)

    def _has_requests_in_progress(self):
        return self.driver.is_running()

    async def run_loop_fwd(self):
        while True:
            # Re-check on every iteration so the loop can idle once all requests finish.
            if not self._has_requests_in_progress():
                await self._request_tracker.wait_for_new_requests()
            self._step()
            await asyncio.sleep(0)

    @property
    def is_running(self):
        return self.background_loop is not None and not self.background_loop.done()

    def start_background_loop(self):
        if self.is_running:
            raise RuntimeError("Background loop is already running.")

        self._request_tracker.init_event()

        self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
        self.background_loop = asyncio.shield(self.background_loop_unshielded)

    async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        self.driver.add_input(request_id, prompt, sampling_params)
        self._request_tracker.add_request(request_id)

    async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        """
        The only exposed function: add a new request and return an async generator
        that yields the finished results.
        """
        try:
            if not self.is_running:
                self.start_background_loop()

            await self.add_request(request_id, prompt, sampling_params)

            async for request_output in self._request_tracker:
                yield request_output

        except (Exception, asyncio.CancelledError) as e:
            # If there is an exception or the coroutine is cancelled, abort the request.
            self.abort_request(request_id)
            raise e
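Usage note (not part of the committed file): a minimal sketch of consuming the async engine, assuming a YAML config in the format shown by ray_init_config.py further below; the config path, request id, and prompt are illustrative.

import asyncio

from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def run_one_prompt(config_path: str, prompt: str) -> None:
    # Load engine/router configs from YAML (the path is illustrative).
    config = RayInitConfig.from_yaml_path(config_path)
    engine = Async_Engine(router_config=config.router_config_data, engine_config=config.engine_config_data)
    # generate() is an async generator; each yielded item is a RequestOutput.
    async for output in engine.generate("req-0", prompt, SamplingParams(max_new_tokens=32)):
        print(output)


# asyncio.run(run_one_prompt("my_config.yaml", "Introduce some landmarks in Beijing"))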
@@ -0,0 +1,151 @@
from typing import List

from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine


class Async_DynamicBatchManager(DynamicBatchManager):
    def __init__(
        self,
        tp_engine: TPInferEngine,
        max_total_token_num: int,
        batch_max_tokens: int,
        model: str,
        tokenizer=None,
        eos_id=None,
        log_stats=True,
        log_stats_interval=10,
        running_batch: Batch = None,
        waiting_req_list: List = None,
    ):
        """
        Args:
            tp_engine : the TP engine that the dynamic batch manager holds, defined before the dynamic batch manager
            max_total_token_num : max_total_token_num for the memory manager, defaults to max batch size * (max input len + max output len)
            batch_max_tokens : max tokens of one batch, defaults to (max input len + max output len) * num_requests
            running_max_req_size : max request size of the running batch, equal to MAX_BATCH_SIZE of the TP engine
            eos_id : the end token of a sequence
            model : the model weight dir path; the app loads config, weights and tokenizer from this dir
            log_stats : whether to log stats
            log_stats_interval : log stats interval
            running_batch : running batch
            waiting_req_list : list of waiting requests, initialized before the dynamic batch manager
        """
        super().__init__(
            tp_engine,
            max_total_token_num,
            batch_max_tokens,
            model,
            tokenizer,
            eos_id,
            log_stats,
            log_stats_interval,
            running_batch,
            waiting_req_list,
        )

    def _step(self):
        """
        Logic for handling requests.
        """
        has_new_finished = False
        if self.running_batch is None:
            new_batch = self.req_queue.generate_new_batch(self.running_batch)
            if new_batch is not None:
                self.stats_tool.count_prompt_tokens(new_batch)
                self.running_batch = new_batch
                has_new_finished, outputs = self._prefill_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens = 0

        else:
            if self.has_wait_tokens < self.max_wait_tokens:
                self.stats_tool.count_output_tokens(self.running_batch)
                has_new_finished, outputs = self._decode_batch(self.running_batch)
                self._filter_runing_batch()
                self.has_wait_tokens += 1

            else:
                new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
                if new_mini_batch is not None:
                    self.stats_tool.count_prompt_tokens(new_mini_batch)
                    has_new_finished, outputs = self._prefill_batch(new_mini_batch)
                    if not new_mini_batch.is_clear():
                        self._merge_batch(self.running_batch, new_mini_batch)
                        self.running_batch.merge(new_mini_batch)
                    self.has_wait_tokens = 0

                else:
                    self.stats_tool.count_output_tokens(self.running_batch)
                    has_new_finished, outputs = self._decode_batch(self.running_batch)
                    self._filter_runing_batch()
                    self.has_wait_tokens += 1

        if has_new_finished:
            return outputs
        return None

    def _prefill_batch(self, batch):
        """
        For every batch, whether it is a new batch or a mini-batch, prefill must run first.
        """
        self._init_batch(batch)

        # TODO: figure out whether the cache and batch id are needed
        ans = self.engine._prefill_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        # handle finished reqs and delete them from the batch
        outputs = self._handle_finish_req(batch, has_new_finished_req)
        return has_new_finished_req, outputs

    def _decode_batch(self, batch: Batch):
        """
        Decoding process.
        """
        ans = self.engine._decode_batch(batch.batch_id)
        req_to_out_token_id = ans
        self._add_token_id_to_req(batch, req_to_out_token_id)
        has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
        outputs = self._handle_finish_req(batch, has_new_finished_req)
        return has_new_finished_req, outputs

    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
        if has_new_finished_req:
            finished_reqs = batch.filter_finished()
            if batch.is_clear():
                self._remove_batch(batch)
            else:
                self._filter_batch(batch)
            return self._output_process(finished_reqs)
        return None

    def _output_process(self, finished_reqs: List[Req]):
        """
        Process the output of a batch.
        """
        outputs = []
        for req in finished_reqs:
            output = self.tokenizer.decode(req.output_ids)
            outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
        return outputs


def start_dynamic_batching(args, tp_engine, waiting_req_list):
    batch_manager = Async_DynamicBatchManager(
        tp_engine=tp_engine,
        max_total_token_num=args.max_total_token_num,
        batch_max_tokens=args.batch_max_tokens,
        eos_id=args.eos_id,
        model=args.model,
        log_stats=not args.disable_log_stats,
        log_stats_interval=args.log_stats_interval,
        waiting_req_list=waiting_req_list,
    )
    return batch_manager
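The _step above follows a simple prefill/decode interleaving heuristic: decode the running batch for up to max_wait_tokens iterations, then try to admit a new prefill mini-batch and merge it in. A stripped-down, framework-free sketch of the same control flow (all names are local to this illustration and do not appear in the files above):

from collections import deque

MAX_WAIT = 3  # decode steps between admission attempts (toy value)


def schedule_step(running, waiting, wait_count):
    """Return (action, running, wait_count), mimicking _step's control flow."""
    if running is None:
        if waiting:
            running = [waiting.popleft()]  # prefill a fresh batch
            return "prefill", running, 0
        return "idle", running, wait_count
    if wait_count < MAX_WAIT:
        return "decode", running, wait_count + 1  # keep decoding the running batch
    if waiting:
        running.append(waiting.popleft())  # prefill and merge a mini-batch
        return "prefill-mini", running, 0
    return "decode", running, wait_count + 1


waiting = deque(["req-0", "req-1"])
running, wait = None, 0
for _ in range(6):
    action, running, wait = schedule_step(running, waiting, wait)
    print(action, running)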
@@ -0,0 +1,40 @@
"""
Motivated by vLLM (https://github.com/vllm-project/vllm), this module tries to resolve the tokenizer issue.

license: MIT, see LICENSE for more details.
"""

from transformers import AutoTokenizer

_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer=None,
    tokenizer_name: str = "",
    trust_remote_code: bool = False,
    use_fast: bool = True,
):
    if tokenizer is None:
        if "llama" in tokenizer_name.lower() and use_fast:
            print(
                "For some LLaMA-based models, initializing the fast tokenizer may "
                "take a long time. To eliminate the initialization time, consider "
                f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
                "tokenizer. This is done automatically in Colossalai."
            )

            tokenizer_name = _FAST_LLAMA_TOKENIZER

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
            )
        except TypeError:
            # Fall back to the slow tokenizer if the fast one is unavailable.
            use_fast = False
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
            )
    return tokenizer
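A quick hedged example of the helper above; any HuggingFace tokenizer name or local path works, and a LLaMA-style name is transparently swapped for the fast community tokenizer:

from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer

tokenizer = get_tokenizer(tokenizer_name="hf-internal-testing/llama-tokenizer")
ids = tokenizer.encode("hello world")
print(ids, tokenizer.decode(ids))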
@@ -0,0 +1,346 @@
# Adapted from https://github.com/ModelTC/lightllm

import collections
from dataclasses import dataclass
from typing import Dict, List, Tuple

import numpy as np
import torch

from colossalai.inference.tensor_parallel import MemoryManager


# make batch infer state an attr of InferBatch
class InferSamplingParams:
    def __init__(
        self,
        do_sample: bool = False,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,
        vocab_size: int = -1,
    ) -> None:
        self.do_sample = do_sample
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        if self.top_k == -1:
            self.top_k = vocab_size
        return


@dataclass
class InferBatch:
    batch_id: int
    requests: List
    requests_idx_mapping: Dict[int, int]

    input_ids: torch.Tensor

    all_input_ids: List[List[int]]
    input_lengths: List[int]

    out_token_id_counts: List
    sampling_param_list: List[InferSamplingParams]

    nopad_total_token_num: int
    nopad_max_len_in_batch: int
    nopad_b_loc: torch.Tensor
    nopad_b_start_loc: torch.Tensor
    nopad_b_seq_len: torch.Tensor
    cache_manager: MemoryManager
    max_total_len: int

    @classmethod
    @torch.no_grad()
    def init_batch(
        cls,
        batch_id,
        requests,
        dtype: torch.dtype,
        device: torch.device,
        cache_manager: MemoryManager,
        vocab_size: int,
        max_total_len: int,
    ) -> "InferBatch":
        input_lengths = []
        all_input_ids = []
        requests_idx_mapping = {}

        out_token_id_counts = []
        sampling_param_list = []

        nopad_total_token_num = 0
        nopad_max_len_in_batch = 0
        nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda")
        # To avoid a memory leak, we pre-allocate 12 extra slots for each batch.
        nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda")
        for i, r in enumerate(requests):
            # request id -> idx in list mapping
            requests_idx_mapping[r["request_id"]] = i

            tokenized_input = r["input_id"]

            input_length = len(tokenized_input)
            input_lengths.append(input_length)
            all_input_ids.append(tokenized_input)
            out_token_id_counts.append(collections.defaultdict(int))

            # postprocessor
            sampling_param = r["sampling_param"]
            sampling_param["vocab_size"] = vocab_size
            sampling_param_list.append(InferSamplingParams(**sampling_param))

            nopad_total_token_num += input_length
            nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length)

        nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda")
        nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]

        if len(requests) > 1:
            input_ids = np.concatenate(all_input_ids, dtype=np.int64)
        else:
            input_ids = all_input_ids[0]

        # Create tensors on device
        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)

        return cls(
            batch_id=batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            nopad_total_token_num=nopad_total_token_num,
            nopad_max_len_in_batch=nopad_max_len_in_batch,
            nopad_b_loc=nopad_b_loc,
            nopad_b_start_loc=nopad_b_start_loc,
            nopad_b_seq_len=nopad_b_seq_len,
            out_token_id_counts=out_token_id_counts,
            sampling_param_list=sampling_param_list,
            cache_manager=cache_manager,
            max_total_len=max_total_len,
        )

    @torch.no_grad()
    def free_self(self) -> None:
        """
        Free the memory of the InferBatch itself.
        """
        remove_index = []
        for idx in range(len(self)):
            remove_index.append(
                self.nopad_b_loc[
                    idx,
                    (self.nopad_max_len_in_batch - 1)
                    - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
                ]
            )
        remove_index = torch.cat(remove_index, dim=-1)
        self.cache_manager.free(remove_index)

    @torch.no_grad()
    def filter(self, request_ids: List[int]) -> "InferBatch":
        """
        Filter out finished requests and return a new InferBatch with the remaining ones.
        """
        if len(request_ids) == 0:
            raise ValueError("Batch must have at least one request")
        if len(request_ids) == len(self):
            return self
        requests_idx_mapping = {}
        indices = []
        requests = []
        all_input_ids = []
        input_lengths = []
        nopad_total_token_num = 0
        nopad_max_len_in_batch = 0
        nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda")
        nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
        nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")

        left_idx = []
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            left_idx.append(idx)

        left_idx_set = set(left_idx)
        remove_index = []
        for idx in range(len(self)):
            if idx not in left_idx_set:
                remove_index.append(
                    self.nopad_b_loc[
                        idx,
                        (self.nopad_max_len_in_batch - 1)
                        - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
                    ]
                )
        remove_index = torch.cat(remove_index, dim=-1)
        self.cache_manager.free(remove_index)

        nopad_max_len_in_batch = 0
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            indices.append(idx)

        nopad_b_seq_len[:] = self.nopad_b_seq_len[indices]
        nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item()
        nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
        nopad_total_token_num = torch.sum(nopad_b_seq_len).item()

        nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[
            indices,
            (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1),
        ]
        for i, request_id in enumerate(request_ids):
            idx = self.requests_idx_mapping[request_id]
            requests_idx_mapping[request_id] = i
            requests.append(self.requests[idx])
            all_input_ids.append(self.all_input_ids[idx])
            input_lengths.append(self.input_lengths[idx])

        input_ids = self.input_ids[indices]

        return InferBatch(
            batch_id=self.batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            nopad_total_token_num=nopad_total_token_num,
            nopad_max_len_in_batch=nopad_max_len_in_batch,
            nopad_b_loc=nopad_b_loc,
            nopad_b_start_loc=nopad_b_start_loc,
            nopad_b_seq_len=nopad_b_seq_len,
            out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices],
            sampling_param_list=[self.sampling_param_list[_i] for _i in indices],
            cache_manager=self.cache_manager,
            max_total_len=self.max_total_len,
        )

    @classmethod
    @torch.no_grad()
    def merge(cls, batch1, batch2) -> "InferBatch":
        """
        Return a merged new InferBatch.
        """
        requests = batch1.requests + batch2.requests
        requests_idx_mapping = {}
        new_batch_size = len(batch1) + len(batch2)

        input_ids = batch1.input_ids.new_empty(new_batch_size)
        all_input_ids = []
        input_lengths = []
        out_token_id_counts = []
        sampling_param_list = []

        cumulative_batch_size = 0
        nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num
        nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch)
        max_total_len = max(batch1.max_total_len, batch2.max_total_len)
        nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda")
        nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
        nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
        nopad_start_loc_len_temp = 0
        batches = [batch1, batch2]
        for i, batch in enumerate(batches):
            if i == 0:
                requests_idx_mapping = batch.requests_idx_mapping
            else:
                for k, v in batch.requests_idx_mapping.items():
                    requests_idx_mapping[k] = v + cumulative_batch_size
            start_index = cumulative_batch_size
            end_index = cumulative_batch_size + len(batch)
            input_ids[start_index:end_index] = batch.input_ids
            nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len
            nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp
            nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1]
            nopad_b_loc[
                start_index:end_index,
                nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1,
            ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1]

            all_input_ids.extend(batch.all_input_ids)

            input_lengths.extend(batch.input_lengths)
            out_token_id_counts.extend(batch.out_token_id_counts)
            sampling_param_list.extend(batch.sampling_param_list)
            # Update
            cumulative_batch_size += len(batch)

        nopad_b_loc[:, nopad_max_len_in_batch - 1] = (
            nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda")
        )
        return InferBatch(
            batch_id=batches[0].batch_id,
            requests=requests,
            requests_idx_mapping=requests_idx_mapping,
            input_ids=input_ids,
            input_lengths=input_lengths,
            all_input_ids=all_input_ids,
            nopad_total_token_num=nopad_total_token_num,
            nopad_max_len_in_batch=nopad_max_len_in_batch,
            nopad_b_loc=nopad_b_loc,
            nopad_b_start_loc=nopad_b_start_loc,
            nopad_b_seq_len=nopad_b_seq_len,
            out_token_id_counts=out_token_id_counts,
            sampling_param_list=sampling_param_list,
            cache_manager=batches[0].cache_manager,
            max_total_len=max_total_len,
        )

    def __len__(self):
        return len(self.requests)

    def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        presence_penalties: List[float] = []
        frequency_penalties: List[float] = []
        temperatures: List[float] = []
        top_ps: List[float] = []
        top_ks: List[int] = []
        p_token_ids: List[int] = []
        p_token_counts: List[int] = []
        p_seq_len: List[int] = [
            0,
        ]
        p_max_len_in_batch: int = 0
        for i, id_to_count in enumerate(self.out_token_id_counts):
            sample_param = self.sampling_param_list[i]
            presence_penalties.append(sample_param.presence_penalty)
            frequency_penalties.append(sample_param.frequency_penalty)
            temperatures.append(sample_param.temperature)
            top_ps.append(sample_param.top_p)
            top_ks.append(sample_param.top_k)

            for token_id, count in id_to_count.items():
                p_token_ids.append(token_id)
                p_token_counts.append(count)
            p_seq_len.append(len(id_to_count))
            p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count))

        presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda")
        frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda")
        temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda")
        top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda")
        top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda")
        p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda")
        p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda")
        p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
        p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
        return (
            presence_penalties,
            frequency_penalties,
            temperatures,
            top_ps,
            top_ks,
            p_token_ids,
            p_token_counts,
            p_cumsum_seq_len,
            p_max_len_in_batch,
        )
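To see why nopad_b_start_loc[1:] = cumsum(seq_len)[:-1] gives per-request offsets into the packed (un-padded) token buffer, here is a small self-contained sketch with made-up lengths (CPU tensors, so it runs anywhere):

import torch

seq_len = torch.tensor([3, 5, 2], dtype=torch.int32)  # three requests
start_loc = torch.zeros(3, dtype=torch.int32)
start_loc[1:] = torch.cumsum(seq_len, dim=0, dtype=torch.int32)[0:-1]
# [0, 3, 8]: request i's tokens live at [start_loc[i], start_loc[i] + seq_len[i])
print(start_loc.tolist())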
@@ -0,0 +1,166 @@
# Adapted from https://github.com/ModelTC/lightllm

from typing import Dict, List, Tuple

from .sampling_params import SamplingParams


class Req:
    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""):
        self.request_id = request_id
        self.prompt_ids = prompt_ids
        self.input_len = len(prompt_ids)
        self.max_output_len = sample_params.max_new_tokens
        self.sample_params = sample_params
        self.output_ids = []
        self.output_metadata_list = []
        self.has_generate_finished = False
        self.aborted = False
        self.prompts = prompts

    def to_rpc_obj(self):
        return {
            "request_id": self.request_id,
            "input_id": self.prompt_ids,
            "output_len": self.max_output_len,
            "sampling_param": self.sample_params.to_dict(),
        }

    def stop_sequences_matched(self):
        # should we add stop sequences to the sample params?
        if self.sample_params.stop_sequences is not None:
            for stop_token_ids in self.sample_params.stop_sequences:
                stop_len = len(stop_token_ids)
                if (
                    stop_len > 0
                    and len(self.output_ids) >= stop_len
                    and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
                ):
                    return True
        return False

    def __repr__(self):
        return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, "


class Batch:
    def __init__(self, batch_id, reqs: List[Req]):
        self.batch_id = batch_id
        self.reqs = reqs
        self.id_to_reqs = {req.request_id: req for req in reqs}

    def input_tokens(self):
        batch_input_tokens = 0
        for req in self.reqs:
            batch_input_tokens += req.input_len
        return batch_input_tokens

    def calcu_max_tokens(self):
        tokens = 0
        for req in self.reqs:
            tokens += req.input_len + req.max_output_len
        return tokens

    def calcu_used_tokens(self):
        tokens = 0
        for req in self.reqs:
            tokens += req.input_len + len(req.output_ids)
        return tokens

    def mark_finished_req(self, eos_id, engine_max_output_len):
        has_new_finish = False
        for req in self.reqs:
            if req.stop_sequences_matched():
                req.has_generate_finished = True
                has_new_finish = True
            if len(req.output_ids) >= engine_max_output_len:
                req.has_generate_finished = True
                has_new_finish = True
            if req.output_ids[-1] == eos_id and not req.sample_params.ignore_eos:
                req.has_generate_finished = True
                has_new_finish = True
            if len(req.output_ids) >= req.max_output_len or req.aborted:
                req.has_generate_finished = True
                has_new_finish = True
        return has_new_finish

    def filter_finished(self) -> List[Req]:
        """
        Filter finished requests from the batch; the finished ones are removed from 'reqs'.
        """
        # TODO: the logic of the return value should be defined here.
        unfinished_req = []
        finished_req = []
        for req in self.reqs:
            if not req.has_generate_finished:
                unfinished_req.append(req)
            else:
                finished_req.append(req)
        self.reqs = unfinished_req
        self.id_to_reqs = {req.request_id: req for req in self.reqs}
        return finished_req

    def is_clear(self):
        return len(self.reqs) == 0

    def merge(self, mini_batch):
        for _req in mini_batch.reqs:
            self.reqs.append(_req)
        self.id_to_reqs = {req.request_id: req for req in self.reqs}
        return

    def __repr__(self):
        return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, "

    def __len__(self):
        return len(self.reqs)


class BatchTokenIdOut:
    def __init__(self):
        self.reqs_infs: List[
            Tuple[str, int, Dict, bool, bool]
        ] = []  # [req_id, new_token_id, gen_metadata, finished_state, abort_state]


class BatchStrOut:
    def __init__(self):
        self.reqs_infs: List[
            Tuple[str, str, Dict, bool, bool]
        ] = []  # [req_id, token_str, gen_metadata, finished_state, abort_state]


class AbortReq:
    def __init__(self, req_id):
        self.req_id = req_id


class RequestOutput:
    """The output data of a request to the LLM.

    Args:
        request_id: The unique ID of the request.
        prompt: The prompt string of the request.
        prompt_token_ids: The token IDs of the prompt.
        outputs: The output sequences of the request.
    """

    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        outputs,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.outputs = outputs

    def __repr__(self) -> str:
        return (
            f"RequestOutput(request_id={self.request_id}, "
            f"prompt={self.prompt!r}, "
            f"prompt_token_ids={self.prompt_token_ids}, "
            f"outputs={self.outputs}, "
        )
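The suffix check in stop_sequences_matched compares the tail of output_ids against each stop sequence. A tiny standalone rendering of the same logic, with made-up token ids:

def suffix_matches(output_ids, stop_token_ids):
    # True when the last len(stop_token_ids) generated ids equal the stop sequence.
    stop_len = len(stop_token_ids)
    return (
        stop_len > 0
        and len(output_ids) >= stop_len
        and all(output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
    )


print(suffix_matches([5, 9, 2, 7], [2, 7]))  # True: [2, 7] is the suffix
print(suffix_matches([5, 9, 2, 7], [9, 2]))  # False: [9, 2] is not the suffix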
@@ -0,0 +1,152 @@
import logging
import os
from typing import List

import ray
import ray.util.collective as collective
import torch
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.inference.async_manager import start_dynamic_batching
from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer
from colossalai.inference.dynamic_batching.io_struct import RequestOutput
from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import free_port

ray_serve_logger = logging.getLogger("ray.serve")


def log_cuda_info(scope_name: str):
    ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}")
    ray_serve_logger.info(
        f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}"
    )
    if torch.cuda.is_available():
        ray_serve_logger.info(
            f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}"
        )
    else:
        ray_serve_logger.info(f" {scope_name}: cuda is not available!")


@ray.remote(num_gpus=1)
class Worker:
    def __init__(
        self,
        model_path: str,
        tensor_parallel_size: int,
        max_batch_size: int,
        max_input_len: int,
        max_output_len: int,
        router_config: RooterArgsClass,
    ):
        log_cuda_info("Worker.init")
        self.tensor_parallel_size = tensor_parallel_size
        self.model_path = model_path
        self.max_batch_size = max_batch_size
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        self.router_config = router_config

    def setup(self, world_size, rank, port):
        # Initialize a ray collective group; otherwise the colossalai distributed env won't be built successfully.
        collective.init_collective_group(world_size, rank, "nccl", "default")
        # Initialize and set the distributed environment.
        colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
        ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up...")
        log_cuda_info("Worker.setup")

        # Load the model.
        self.tokenizer = get_tokenizer(tokenizer_name=self.model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
        )
        shard_config = ShardConfig(enable_tensor_parallelism=world_size > 1, inference_only=True)
        self.infer_engine = TPInferEngine(
            self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
        )
        self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, [])

        return True

    # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> List[str]:
    #     ray_serve_logger.info(f"text: {prompt}")

    #     final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id)

    #     return final_outputs

    def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        self.start_dynamic_batching.add_input(request_id, prompt, sampling_params)

    def abort(self, request_id: str):
        self.start_dynamic_batching.abort(request_id)

    def step(self) -> List[RequestOutput]:
        return self.start_dynamic_batching._step()

    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
        self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt)

    def is_running(self):
        return self.start_dynamic_batching.is_running()


class Driver:
    def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass):
        log_cuda_info("Driver:init")
        model_path = engine_config.model
        tensor_parallel_size = engine_config.tensor_parallel_size

        self.num_workers = tensor_parallel_size
        self.workers = []
        init_rets = []

        # Just grab a free port on localhost.
        # NOTE: workers in this communication group listen on the same port.
        available_port = free_port()

        for i in range(self.num_workers):
            worker_name = "worker_idx_{}".format(i)
            w = Worker.options(name=worker_name).remote(
                model_path,
                self.num_workers,
                engine_config.max_batch_size,
                engine_config.max_input_len,
                engine_config.max_output_len,
                router_config,
            )
            self.workers.append(w)
            init_rets.append(w.setup.remote(self.num_workers, i, available_port))
        _options = {
            "group_name": "default_driver",
            "world_size": self.num_workers,
            "ranks": [i for i in range(self.num_workers)],
            "backend": "nccl",
        }
        collective.create_collective_group(self.workers, **_options)
        _ = ray.get(init_rets)

    def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
        ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers])

    def abort(self, request_id: str):
        ray.get([w.abort.remote(request_id) for w in self.workers])

    def step(self):
        results = ray.get([w.step.remote() for w in self.workers])
        outputs = results[0]  # get any one of the copies
        return outputs

    def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str):
        ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])

    def is_running(self):
        results = ray.get([w.is_running.remote() for w in self.workers])
        return any(results)
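A hedged sketch of driving the Ray workers directly (assumes a running Ray cluster with at least tensor_parallel_size GPUs; the model path, token limits, and request id are illustrative):

import ray

from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

ray.init()  # or connect to an existing cluster
engine_config = EngineArgsClass(model="/path/to/model", tensor_parallel_size=1)  # path is illustrative
router_config = RooterArgsClass(max_total_token_num=4096, batch_max_tokens=4096, model="/path/to/model")

driver = Driver(router_config=router_config, engine_config=engine_config)
driver.add_input("req-0", "Hello, my name is", SamplingParams(max_new_tokens=16))
while driver.is_running():
    outputs = driver.step()  # every worker steps; one replica's results are returned
    if outputs:
        print(outputs)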
@@ -0,0 +1,58 @@
import logging

import yaml
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class EngineArgsClass(BaseModel):
    """Config for the engine"""

    model: str
    tensor_parallel_size: int = 2
    max_batch_size: int = 4
    max_input_len: int = 128
    max_output_len: int = 32


class RooterArgsClass(BaseModel):
    """Config for the router"""

    max_total_token_num: int = 42
    batch_max_tokens: int = 42
    eos_id: int = 0
    disable_log_stats: bool = False
    log_stats_interval: int = 10
    model: str


class RayInitConfig(BaseModel):
    """All-together configs without the app router config"""

    engine_config_data: EngineArgsClass
    router_config_data: RooterArgsClass

    @classmethod
    def from_yaml_path(cls, path: str):
        try:
            with open(path, "r") as yaml_file:
                try:
                    config = yaml.safe_load(yaml_file)
                    # serve deployment config
                    engine_config = config.get("engine_config", {})
                    router_config = config.get("router_config", {})

                    return cls(
                        engine_config_data=engine_config,
                        router_config_data=router_config,
                    )
                except yaml.YAMLError as e:
                    logger.error(f"An error occurred when parsing yaml: {e}")
                    raise
        except FileNotFoundError:
            logger.error(f"The file '{path}' does not exist!")
            raise
        except OSError as e:
            logger.error(f"An error occurred: {e}")
            raise
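Since these config classes are plain pydantic models, they can also be built directly from dictionaries, which is what from_yaml_path does under the hood. A minimal sketch (the model path is illustrative):

from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig

config = RayInitConfig(
    engine_config_data={"model": "/path/to/model", "tensor_parallel_size": 1},
    router_config_data={"model": "/path/to/model"},  # other router fields fall back to defaults
)
print(config.engine_config_data.max_batch_size)  # 4 (the default)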
@@ -0,0 +1,73 @@
# Adapted from https://github.com/ModelTC/lightllm

import uuid
from typing import List

import numpy as np

from .io_struct import Batch, Req


class ReqQueue:
    def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=None) -> None:
        self.max_total_tokens = max_total_tokens
        assert batch_max_tokens is not None
        self.batch_max_tokens = batch_max_tokens
        self.running_max_req_size = running_max_req_size
        # Guard against the shared-mutable-default pitfall.
        self.waiting_req_list: List[Req] = waiting_req_list if waiting_req_list is not None else []

    def append(self, req):
        self.waiting_req_list.append(req)
        return

    def _init_cache_list(self, current_batch: Batch):
        if current_batch is not None:
            self.cache_len_list = [
                (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1)
                for req in current_batch.reqs
            ]
        else:
            self.cache_len_list = []

    # @calculate_time(show=True, min_cost_ms=0.1)
    def _can_add_new_req(self, req):
        self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1))  # hard to analyze
        self.cache_len_list.sort(key=lambda x: -x[1])

        left_out_len_array = np.array([e[1] for e in self.cache_len_list])
        # assert left_out_len_array.min() >= 0
        has_run_len_array = np.array([e[0] for e in self.cache_len_list])
        cum_run_len_array = np.cumsum(has_run_len_array)
        size_array = np.arange(1, len(self.cache_len_list) + 1, 1)

        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
        # NOTE: changed here from < to <=
        return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size

    def generate_new_batch(self, current_batch: Batch = None):
        if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size:
            return None
        self._init_cache_list(current_batch)
        can_run_list = []
        new_batch_total_tokens = 0
        aborted_count = 0
        for req in self.waiting_req_list:
            flag = self._can_add_new_req(req)
            if req.aborted:
                aborted_count += 1
                continue
            if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens:
                can_run_list.append(req)
                new_batch_total_tokens += req.input_len
            else:
                break

        if len(can_run_list) != 0:
            new_batch = Batch(uuid.uuid4().hex, can_run_list)
            self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
            return new_batch
        else:
            return None

    def __len__(self):
        return len(self.waiting_req_list)
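The admission check in _can_add_new_req bounds peak KV-cache usage: with requests sorted by remaining output length (descending), the worst case for the i-th request is that all i longer-remaining requests are still alive when it finishes, costing left_out[i] * (i + 1) + cum_run[i] tokens. A small numeric sketch of the same formula with made-up request sizes:

import numpy as np

# (already_run_tokens, remaining_output_tokens) per request, sorted by remaining desc
cache_len_list = [(4, 10), (6, 7), (3, 2)]
left_out = np.array([e[1] for e in cache_len_list])  # [10, 7, 2]
has_run = np.array([e[0] for e in cache_len_list])   # [4, 6, 3]
cum_run = np.cumsum(has_run)                         # [4, 10, 13]
size = np.arange(1, len(cache_len_list) + 1)         # [1, 2, 3]
need_max_token_num = (left_out * size + cum_run).max()
print(need_max_token_num)  # 24: admit the request only if this fits in max_total_tokens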
@@ -0,0 +1,83 @@
# Adapted from https://github.com/ModelTC/lightllm

"""Sampling parameters for text generation."""
from typing import List, Optional, Union

_SAMPLING_EPS = 1e-5


class SamplingParams:
    def __init__(
        self,
        do_sample: bool = False,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = -1,  # -1 means consider all tokens
        ignore_eos: bool = False,
        max_new_tokens: int = 256,
        stop_sequences: Optional[Union[str, List[str]]] = None,  # conditions to stop generation
    ) -> None:
        self.do_sample = do_sample
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.ignore_eos = ignore_eos
        self.max_new_tokens = max_new_tokens
        self.stop_sequences = stop_sequences
        if not self.do_sample:
            self.temperature = 1.0
            self.top_p = 1.0
            self.top_k = 1
        if (
            self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS
        ):  # temperature is too low; fall back to greedy search
            self.temperature = 1.0
            self.top_k = 1
        return

    def verify(self):
        if self.presence_penalty < 0.0:
            raise ValueError(f"presence_penalty must be >= 0.0, got {self.presence_penalty}")
        if self.frequency_penalty < 0.0:
            raise ValueError(f"frequency_penalty must be >= 0.0, got {self.frequency_penalty}")
        if self.temperature <= 0.0:
            raise ValueError(f"temperature must be > 0.0, got {self.temperature}")
        if self.top_p <= 0.0 or self.top_p > 1.0:
            raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
        if self.top_k < -1 or self.top_k == 0:
            raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
        if self.max_new_tokens < 1:
            raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
        return

    def stop_sentences_to_token_ids(self, tokenizer):
        if self.stop_sequences is None:
            self.stop_sequences = []
        else:
            if isinstance(self.stop_sequences, str):
                self.stop_sequences = [self.stop_sequences]
            new_stop_sequences = []
            for stop_str in self.stop_sequences:
                stop_str_ids = tokenizer.encode(stop_str)
                if stop_str_ids is not None and len(stop_str_ids) >= 1:  # remove bos_token_id
                    stop_str_ids = stop_str_ids[1:]
                if len(stop_str_ids) > 0:
                    new_stop_sequences.append(stop_str_ids)
            self.stop_sequences = new_stop_sequences
        return

    def to_dict(self):
        ret = {}
        ret["do_sample"] = self.do_sample
        ret["presence_penalty"] = self.presence_penalty
        ret["frequency_penalty"] = self.frequency_penalty
        ret["temperature"] = self.temperature
        ret["top_p"] = self.top_p
        ret["top_k"] = self.top_k
        # if self.ignore_eos is not None:
        #     ret["ignore_eos"] = self.ignore_eos
        return ret
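A short example of constructing and validating sampling parameters; note how do_sample=False forces greedy settings inside __init__:

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

greedy = SamplingParams(do_sample=False, temperature=0.7, top_k=50)
print(greedy.temperature, greedy.top_p, greedy.top_k)  # 1.0 1.0 1 (reset for greedy search)

sampled = SamplingParams(do_sample=True, temperature=0.8, top_p=0.95, top_k=50, max_new_tokens=64)
sampled.verify()  # raises ValueError on out-of-range values
print(sampled.to_dict())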
@@ -0,0 +1,45 @@
# Adapted from https://github.com/ModelTC/lightllm

import time


class Stats:
    def __init__(self, log_status, log_stats_interval) -> None:
        self.log_stats = log_status
        self.log_stats_interval = log_stats_interval
        self.last_log_time = time.time()
        self.all_tokens = 0
        self.output_tokens = 0
        self.prompt_tokens = 0
        return

    def count_prompt_tokens(self, run_batch):
        if self.log_stats:
            tokens = run_batch.input_tokens()
            self.prompt_tokens += tokens
            self.all_tokens += tokens
        return

    def count_output_tokens(self, run_batch):
        if self.log_stats:
            tokens = len(run_batch.reqs)
            self.output_tokens += tokens
            self.all_tokens += tokens
        return

    def print_stats(self):
        if not self.log_stats:
            return

        now = time.time()
        if now - self.last_log_time > self.log_stats_interval:
            print(
                f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
                f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
                f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
            )
            self.all_tokens = 0
            self.output_tokens = 0
            self.prompt_tokens = 0
            self.last_log_time = now
        return
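A tiny demonstration of the stats helper with a stub batch object (the stub stands in for a real Batch; interval 0 is used only so the demo prints immediately):

import time

from colossalai.inference.dynamic_batching.stats import Stats


class _StubBatch:
    """Stands in for a real Batch: 4 running requests, 128 prompt tokens."""

    reqs = [None] * 4

    def input_tokens(self):
        return 128


stats = Stats(log_status=True, log_stats_interval=0)
stats.count_prompt_tokens(_StubBatch())  # +128 prompt tokens
stats.count_output_tokens(_StubBatch())  # +4 output tokens (one per running request)
time.sleep(0.01)
stats.print_stats()  # prints the three throughput lines, then resets the counters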
@ -0,0 +1,296 @@
|
||||
# Adapted from https://github.com/ModelTC/lightllm
|
||||
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from .dynamic_batching.get_tokenizer import get_tokenizer
|
||||
from .dynamic_batching.infer_batch import InferBatch
|
||||
from .dynamic_batching.io_struct import Batch, Req
|
||||
from .dynamic_batching.req_queue import ReqQueue
|
||||
from .dynamic_batching.sampling_params import SamplingParams
|
||||
from .dynamic_batching.stats import Stats
|
||||
from .tensor_parallel import TPInferEngine
|
||||
|
||||
|
||||
class DynamicBatchManager:
|
||||
def __init__(
|
||||
self,
|
||||
tp_engine: TPInferEngine,
|
||||
max_total_token_num,
|
||||
batch_max_tokens,
|
||||
model,
|
||||
tokenizer=None,
|
||||
eos_id=None,
|
||||
log_stats=True,
|
||||
log_stats_interval=10,
|
||||
running_batch: Batch = None,
|
||||
waiting_req_list: List = [],
|
||||
):
|
||||
"""
|
||||
Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager
|
||||
max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len)
|
||||
batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
|
||||
running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
|
||||
eos_id : The end token of a seq
|
||||
model: the model weight dir path, the app will load config, weights and tokenizer from this dir
|
||||
log_stats : whether to log stats
|
||||
log_stats_interval : log stats interval
|
||||
running_batch : running batch
|
||||
waiting_req_list : list of waiting requests, initialized before dynamic batch manager
|
||||
"""
|
||||
self.engine = tp_engine
|
||||
self.max_total_token_num = max_total_token_num
|
||||
running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2
|
||||
self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list)
|
||||
# all the inputs should be put into req_queue: waiting req list
|
||||
assert max_total_token_num >= self.engine.max_batch_size * (
|
||||
self.engine.max_input_len + self.engine.max_output_len
|
||||
), "max_total_token_num should be greater than max_batch_size * (max_input_len+max_output_len)"
|
||||
assert (
|
||||
batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len
|
||||
), "batch_max_tokens should be greater than (max_input_len+max_output_len)"
|
||||
self.running_batch: Batch = running_batch
|
||||
self.eos_id = eos_id
|
||||
self.has_wait_tokens = 0
|
||||
self.max_wait_tokens = 10
|
||||
self.model = model
|
||||
|
||||
self.stats_tool = Stats(log_stats, log_stats_interval)
|
||||
self.mem_usage_interval = log_stats_interval * 2
|
||||
self.tokenizer = get_tokenizer(tokenizer_name=self.model) if tokenizer is None else tokenizer
|
||||
if self.eos_id == None:
|
||||
self.eos_id = self.tokenizer.eos_token_id
|
||||
|
||||
def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
|
||||
"""
|
||||
Add new request to req queue, during initialization all requests are held in waiting list.
|
||||
"""
|
||||
sampling_params.max_new_tokens = (
|
||||
self.engine.max_output_len
|
||||
if sampling_params.max_new_tokens > self.engine.max_output_len
|
||||
else sampling_params.max_new_tokens
|
||||
)
|
||||
req = Req(request_id, prompt_ids, sampling_params, prompts)
|
||||
self.req_queue.append(req)
|
||||
return
|
||||
|
||||
def add_input(self, request_id, prompts, sampling_params):
|
||||
"""
|
||||
Encode and Add new input to req queue. support one sequence input for now.
|
||||
"""
|
||||
prompt_ids = self.tokenizer.encode(prompts)
|
||||
prompt_len = len(prompt_ids)
|
||||
if prompt_len > self.engine.max_input_len:
|
||||
raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
|
||||
sampling_params.stop_sentences_to_token_ids(self.tokenizer)
|
||||
self.add_req(request_id, prompt_ids, sampling_params, prompts)
|
||||
return
|
||||
|
||||
def abort(self, request_id):
|
||||
if self.running_batch is not None:
|
||||
for req in self.running_batch.reqs:
|
||||
if req.request_id == request_id:
|
||||
req.has_generate_finished = True
|
||||
req.aborted = True
|
||||
for req in self.req_queue.waiting_req_list:
|
||||
if req.request_id == request_id:
|
||||
req.has_generate_finished = True
|
||||
req.aborted = True
|
||||
return
|
||||
|
||||
def loop_for_fwd(self):
|
||||
"""
|
||||
The main loop for a dynamic batching process.
|
||||
"""
|
||||
counter_count = 0
|
||||
# self.running_batch is not None or self.req_queue.waiting_req_list
|
||||
while self.running_batch is not None or self.req_queue.waiting_req_list:
|
||||
yield from self._step()
|
||||
counter_count += 1
|
||||
if self.running_batch is not None:
|
||||
if counter_count % self.mem_usage_interval == 0:
|
||||
print(
|
||||
"current batch size:",
|
||||
len(self.running_batch.reqs),
|
||||
"token used ratio:",
|
||||
self.running_batch.calcu_used_tokens() / self.max_total_token_num,
|
||||
)
|
||||
self.stats_tool.print_stats()
|
||||
|
||||
if self.running_batch is None:
|
||||
time.sleep(0.1) # 10ms
|
||||
|
||||
def _step(self):
|
||||
"""
|
||||
Logic for handling requests
|
||||
"""
|
||||
|
||||
if self.running_batch is None:
|
||||
new_batch = self.req_queue.generate_new_batch(self.running_batch)
|
||||
if new_batch is not None:
|
||||
self.stats_tool.count_prompt_tokens(new_batch)
|
||||
self.running_batch = new_batch
|
||||
yield from self._prefill_batch(self.running_batch)
|
||||
self._filter_runing_batch()
|
||||
self.has_wait_tokens = 0
|
||||
return
|
||||
|
||||
if self.has_wait_tokens < self.max_wait_tokens:
|
||||
self.stats_tool.count_output_tokens(self.running_batch)
|
||||
yield from self._decode_batch(self.running_batch)
|
||||
self._filter_runing_batch()
|
||||
self.has_wait_tokens += 1
|
||||
return
|
||||
else:
|
||||
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
|
||||
if new_mini_batch is not None:
|
||||
self.stats_tool.count_prompt_tokens(new_mini_batch)
|
||||
yield from self._prefill_batch(new_mini_batch)
|
||||
if not new_mini_batch.is_clear():
|
||||
self._merge_batch(self.running_batch, new_mini_batch)
|
||||
self.running_batch.merge(new_mini_batch)
|
||||
self.has_wait_tokens = 0
|
||||
|
||||
else:
|
||||
self.stats_tool.count_output_tokens(self.running_batch)
|
||||
yield from self._decode_batch(self.running_batch)
|
||||
self._filter_runing_batch()
|
||||
self.has_wait_tokens += 1
|
||||
|
||||
return
|
||||
|
||||
def _init_batch(self, batch: Batch, dtype="fp16"):
|
||||
reqs = [r.to_rpc_obj() for r in batch.reqs]
|
||||
batch_id = batch.batch_id
|
||||
|
||||
import torch
|
||||
|
||||
if dtype == "fp16":
|
||||
dtype = torch.float16
|
||||
else:
|
||||
assert False, "error dtype"
|
||||
|
||||
batch_data = InferBatch.init_batch(
|
||||
batch_id,
|
||||
reqs,
|
||||
dtype,
|
||||
torch.cuda.current_device(),
|
||||
self.engine.cache_manager,
|
||||
self.engine.model.config.vocab_size,
|
||||
self.engine.max_input_len + self.engine.max_output_len,
|
||||
)
|
||||
self.engine.cache[batch_id] = batch_data
|
||||
|
||||
def _prefill_batch(self, batch):
|
||||
"""
|
||||
For all batches, no matter it is a new batch or a mini batch, we need to do prefill first.
|
||||
"""
|
||||
self._init_batch(batch)
|
||||
|
||||
# TODO: figure out if cache and batch id is needed
|
||||
ans = self.engine._prefill_batch(batch.batch_id)
|
||||
req_to_out_token_id = ans
|
||||
self._add_token_id_to_req(batch, req_to_out_token_id)
|
||||
has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
|
||||
yield from self._handle_finish_req(batch, has_new_finished_req)
|
||||
|
||||
# delete finished reqs
|
||||
|
||||
def _decode_batch(self, batch: Batch):
|
||||
"""
|
||||
Decoding process
|
||||
"""
|
||||
ans = self.engine._decode_batch(batch.batch_id)
|
||||
req_to_out_token_id = ans
|
||||
self._add_token_id_to_req(batch, req_to_out_token_id)
|
||||
has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
|
||||
yield from self._handle_finish_req(batch, has_new_finished_req)
|
||||
|
||||
    def _filter_batch(self, batch: Batch):
        # rebuild the cached batch state so it only keeps the still-active requests
        batch_id = batch.batch_id
        req_id_list = [r.request_id for r in batch.reqs]
        batch = self.engine.cache.pop(batch_id)
        filter_batch = batch.filter(req_id_list)
        del batch
        self.engine.cache[batch_id] = filter_batch

    def _merge_batch(self, batch1, batch2):
        """
        Merge a new mini batch into the running batch.
        """
        batch1 = self.engine.cache.pop(batch1.batch_id)
        batch2 = self.engine.cache.pop(batch2.batch_id)

        m_batch = InferBatch.merge(batch1, batch2)
        self.engine.cache[batch1.batch_id] = m_batch
        del batch1
        del batch2

    def _remove_batch(self, batch):
        """
        Remove a finished batch and free its resources.
        """
        batch = self.engine.cache.pop(batch.batch_id)
        batch.free_self()
        del batch

    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
        if has_new_finished_req:
            finished_reqs = batch.filter_finished()
            if batch.is_clear():
                self._remove_batch(batch)
            else:
                self._filter_batch(batch)
            yield from self._output_process(finished_reqs)

    def _filter_running_batch(self):
        if self.running_batch is not None and self.running_batch.is_clear():
            self.running_batch = None

    def _add_token_id_to_req(self, batch: Batch, req_ans):
        for req_id, (new_token_id, new_gen_metadata) in req_ans.items():
            req = batch.id_to_reqs[req_id]
            req.output_ids.append(new_token_id)
            req.output_metadata_list.append(new_gen_metadata)

    def _output_process(self, finished_reqs: List[Req]):
        """
        Decode each finished request and yield its prompt concatenated with
        the generated text.
        """
        for req in finished_reqs:
            output = self.tokenizer.decode(req.output_ids)
            yield req.prompts + output

    def clean_up(self):
        # TODO: free any remaining engine/cache state; to be implemented in the future.
        pass

    def generate(self, request_id, prompts, sampling_params):
        """
        Generate the output of a request.
        """
        self.add_input(request_id, prompts, sampling_params)
        return self.loop_for_fwd()

    def is_running(self):
        # True if there is a running batch or any request still waiting in the queue
        return self.running_batch is not None or len(self.req_queue.waiting_req_list) > 0


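# A minimal, self-contained sketch of the generator pipeline used by the manager
# above: generate() enqueues the request and returns the scheduling loop as a
# generator, and finished outputs are yielded to the caller as they complete.
# ToyManager is a hypothetical stand-in; the string suffix stands in for the
# real prefill/decode steps.
class ToyManager:
    def __init__(self):
        self.queue = []

    def add_input(self, request_id, prompt):
        self.queue.append((request_id, prompt))

    def loop_for_fwd(self):
        while self.queue:
            _request_id, prompt = self.queue.pop(0)
            yield prompt + " <generated>"

    def generate(self, request_id, prompt):
        self.add_input(request_id, prompt)
        return self.loop_for_fwd()

# Usage mirrors the offline test further below:
#     for text in ToyManager().generate("0", "hello"):
#         print(text)
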
def start_dynamic_batching(args, tp_engine, waiting_req_list):
    try:
        batch_manager = DynamicBatchManager(
            tp_engine=tp_engine,
            max_total_token_num=args.max_total_token_num,
            batch_max_tokens=args.batch_max_tokens,
            eos_id=args.eos_id,
            model=args.model,
            log_stats=not args.disable_log_stats,
            log_stats_interval=args.log_stats_interval,
            waiting_req_list=waiting_req_list,
        )
    except Exception as e:
        raise RuntimeError("failed to start dynamic batching") from e

    return batch_manager

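# A sketch of the argument object start_dynamic_batching expects; the field
# names match the args dataclass used in the offline test further below, while
# the defaults shown here are illustrative assumptions, not project defaults.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max_total_token_num", type=int, default=4096)
parser.add_argument("--batch_max_tokens", type=int, default=4096)
parser.add_argument("--eos_id", type=int, default=0)
parser.add_argument("--model", type=str, default="llama")
parser.add_argument("--disable_log_stats", action="store_true")
parser.add_argument("--log_stats_interval", type=int, default=10)
# args = parser.parse_args()
# batch_manager = start_dynamic_batching(args, tp_engine, waiting_req_list=[])
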
@@ -0,0 +1,14 @@
engine_config:
  model: MODEL_PATH
  tensor_parallel_size: 1
  max_batch_size: 2
  max_input_len: 1024
  max_output_len: 512
# config for app router deployment
# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig.
router_config:
  max_total_token_num: 4096
  batch_max_tokens: 4096
  disable_log_stats: False
  log_stats_interval: 10
  model: MODEL_PATH
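# A sketch of loading the YAML above, using only the helpers that appear in the
# tests below (RayInitConfig.from_yaml_path and its engine_config_data /
# router_config_data fields); MODEL_PATH must first be replaced with a real
# checkpoint directory.
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig

config = RayInitConfig.from_yaml_path("config.yaml")
engine_config = config.engine_config_data
router_config = config.router_config_data
print(engine_config.model)  # the model path configured in engine_config
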
@@ -0,0 +1,61 @@
import asyncio
import os
import uuid

import pytest

import colossalai
from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

PATH = "config.yaml"


def run_async_engine(path: str):
    if not os.path.exists(path):
        return

    config = RayInitConfig.from_yaml_path(path)
    engine_config = config.engine_config_data
    model = engine_config.model
    if model is None or not os.path.exists(model):
        return

    prompt = "Introduce some landmarks in London.\n The Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10"
    sampling_params = SamplingParams()
    asyncio.run(asy_for_loop_test(config, prompt, sampling_params))


async def get_result(engine, prompt, sampling_params):
    request_id = str(uuid.uuid4().hex)
    results = engine.generate(request_id, prompt, sampling_params)
    async for result in results:
        assert result is not None


async def asy_for_loop_test(config, prompt, sampling_params):
    router_config = config.router_config_data
    engine_config = config.engine_config_data
    engine = Async_Engine(router_config=router_config, engine_config=engine_config)
    for i in range(10):
        print("in for loop", i)
        await get_result(engine, prompt, sampling_params)


def check_async_engine(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_async_engine(PATH)


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_async_engine():
    spawn(check_async_engine, 1)


if __name__ == "__main__":
    test_async_engine()
@@ -0,0 +1,95 @@
import pytest
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import DynamicBatchManager
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

TP_SIZE = 1
BATCH_SIZE = 2
MAX_INPUT_LEN = 48
MAX_OUTPUT_LEN = 256


def run():
    sampling_params = SamplingParams()

    req1 = Req(0, [1], sampling_params)
    req2 = Req(1, [2], sampling_params)
    req3 = Req(2, [3], sampling_params)
    # reqs 1-3 are initialized as token forward requests
    req4 = Req(3, [10, 10, 10, 9, 1], sampling_params)
    waiting_list = []
    waiting_list.append(req1)
    waiting_list.append(req2)
    waiting_list.append(req3)

    # init model and tp engine
    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
    infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)

    dynamic_batch_manager = DynamicBatchManager(
        tp_engine=infer_engine,
        max_total_token_num=640,
        batch_max_tokens=608,
        eos_id=0,
        log_stats=False,
        log_stats_interval=10,
        waiting_req_list=waiting_list,
        model="llama",
    )
    before_add = len(dynamic_batch_manager.req_queue)

    # test the add_req function
    dynamic_batch_manager.add_req(req4.request_id, req4.prompt_ids, req4.sample_params)
    assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1

    # test the abort function
    dynamic_batch_manager.abort(req4.request_id)
    assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted is True

    # test the filter batch function; loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested as well
    batch = dynamic_batch_manager.req_queue.generate_new_batch()
    assert len(batch) == 2

    dynamic_batch_manager._init_batch(batch)
    assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None

    batch.reqs[0].has_generate_finished = True
    # filter out the one finished request
    batch.filter_finished()
    dynamic_batch_manager._filter_batch(batch)
    assert len(dynamic_batch_manager.engine.cache) == 1

    # test merge batch
    new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch)
    assert len(new_batch) == 1
    dynamic_batch_manager._init_batch(new_batch)
    dynamic_batch_manager._merge_batch(batch, new_batch)

    assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2


def check_dynamic_batching_manager(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching_manager():
    spawn(check_dynamic_batching_manager, 1)


if __name__ == "__main__":
    test_dynamic_batching_manager()
@@ -0,0 +1,84 @@
from dataclasses import dataclass

import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig

import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

TP_SIZE = 1
MAX_BATCH_SIZE = 2
MAX_INPUT_LEN = 5
MAX_OUTPUT_LEN = 16
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


@dataclass
class args:
    max_total_token_num: int
    batch_max_tokens: int
    model: str
    eos_id: int
    disable_log_stats: bool
    log_stats_interval: int


def run():
    arg = args(
        max_total_token_num=42,
        model="llama",
        batch_max_tokens=42,
        eos_id=0,
        disable_log_stats=False,
        log_stats_interval=10,
    )
    sampling_params = SamplingParams()

    req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
    req2 = Req(1, [10, 10, 10, 10, 10], sampling_params)
    req3 = Req(2, [0, 0, 10, 10, 10], sampling_params)
    req4 = Req(3, [0, 0, 10, 10, 10], sampling_params)

    waiting_list = []
    waiting_list.append(req1)
    waiting_list.append(req2)
    waiting_list.append(req3)
    waiting_list.append(req4)

    llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024)
    model = LlamaForCausalLM(llama_config)
    model = model.half()

    shard_config = ShardConfig(enable_tensor_parallelism=TP_SIZE > 1, inference_only=True)

    infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
    batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)

    ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params)
    for result in ans_gen:
        assert result is not None


def check_dynamic_forward(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching():
    spawn(check_dynamic_forward, TP_SIZE)


if __name__ == "__main__":
    test_dynamic_batching()
@@ -0,0 +1,66 @@
import asyncio
import os
import uuid

import pytest

import colossalai
from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

PATH = "config.yaml"


def run_ray_dist(path: str):
    if not os.path.exists(path):
        return
    config = RayInitConfig.from_yaml_path(path)
    router_config = config.router_config_data
    engine_config = config.engine_config_data
    model = engine_config.model
    if model is None or not os.path.exists(model):
        return
    driver = Driver(router_config=router_config, engine_config=engine_config)
    prompt = "Introduce some landmarks in Beijing"

    request_id = str(uuid.uuid4().hex)
    sampling_params = SamplingParams()
    print("sampling_params: ", sampling_params)

    async def get_result(request_id, prompt, sampling_params):
        return await driver.async_generate(request_id, prompt, sampling_params)

    for test_async in [True, False]:
        print("test_async: ", test_async)
        if test_async:
            result = asyncio.run(get_result(request_id, prompt, sampling_params))
        else:
            result = driver.generate(request_id, prompt, sampling_params)
        assert result is not None
        print("result: ", result)

    is_running = driver.is_running()
    assert is_running is not None
    print("is_running: ", is_running)


def check_ray_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_ray_dist(PATH)


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_ray_dist():
    spawn(check_ray_dist, 1)


if __name__ == "__main__":
    test_ray_dist()