mirror of https://github.com/hpcaitech/ColossalAI
[inference]Re push async dynamic batching (#4901)
* adapt to ray server
* finish async
* finish test
* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

pull/4905/head
parent fced140250
commit fbf3c09e67
@@ -4,7 +4,7 @@ from .sampling_params import SamplingParams
 class Req:
-    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams):
+    def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""):
         self.request_id = request_id
         self.prompt_ids = prompt_ids
         self.input_len = len(prompt_ids)
@@ -14,6 +14,7 @@ class Req:
         self.output_metadata_list = []
         self.has_generate_finished = False
         self.aborted = False
+        self.prompts = prompts

     def to_rpc_obj(self):
         return {
@@ -36,7 +37,11 @@ class Req:
         if self.sample_params.stop_sequences is not None:
             for stop_token_ids in self.sample_params.stop_sequences:
                 stop_len = len(stop_token_ids)
-                if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
+                if (
+                    stop_len > 0
+                    and len(self.output_ids) >= stop_len
+                    and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
+                ):
                     return True
         return False
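The reformatted condition above is a suffix match between the newest generated ids and a stop sequence. A minimal standalone sketch of the same check, with illustrative names that are not part of the diff:

    def hits_stop_sequence(output_ids, stop_token_ids):
        # True when the last len(stop_token_ids) generated ids equal the stop sequence
        stop_len = len(stop_token_ids)
        if stop_len == 0 or len(output_ids) < stop_len:
            return False
        return output_ids[-stop_len:] == list(stop_token_ids)

    assert hits_stop_sequence([5, 7, 2, 9], [2, 9])
    assert not hits_stop_sequence([5, 7, 2, 9], [9, 2])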
@@ -1,6 +1,7 @@
-import time
-from typing import List
 import asyncio
+from typing import List
+
+from transformers import AutoTokenizer

 from .dynamic_batching.infer_batch import InferBatch
 from .dynamic_batching.io_struct import Batch, Req
@@ -9,9 +10,9 @@ from .dynamic_batching.sampling_params import SamplingParams
 from .dynamic_batching.stats import Stats
 from .tensor_parallel import TPInferEngine

-from transformers import AutoTokenizer
 _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


 class DynamicBatchManager:
     def __init__(
         self,
@@ -19,6 +20,7 @@ class DynamicBatchManager:
         max_total_token_num,
         batch_max_tokens,
         eos_id,
+        model,
         log_stats=True,
         log_stats_interval=10,
         running_batch: Batch = None,
@@ -30,6 +32,7 @@ class DynamicBatchManager:
         batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
         running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
         eos_id : The end token of a seq
+        model: the model weight dir path, the app will load config, weights and tokenizer from this dir
         log_stats : whether to log stats
         log_stats_interval : log stats interval
         running_batch : running batch
@@ -45,30 +48,30 @@ class DynamicBatchManager:
         self.eos_id = eos_id
         self.has_wait_tokens = 0
         self.max_wait_tokens = 10
+        self.model = model

         self.stats_tool = Stats(log_stats, log_stats_interval)
         self.mem_usage_interval = log_stats_interval * 2
+        self._set_tokenizer(tokenizer_name=self.model)

-    def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str):
+    async def add_req(self, request_id, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
         """
         Add new request to req queue, during initialization all requests are held in waiting list.
         """
-        req = Req(request_id, prompt_ids, sampling_params)
+        req = Req(request_id, prompt_ids, sampling_params, prompts)
         self.req_queue.append(req)
         return

-    def add_input(self, request_id, sampling_params, input_ids):
+    async def add_input(self, request_id, sampling_params, prompts):
         """
         Encode and Add new input to req queue. support one sequence input for now.
         """
-        prompt_ids = self.tokenizer.encode(input_ids)
+        prompt_ids = self.tokenizer.encode(prompts)
         prompt_len = len(prompt_ids)
         if prompt_len > self.engine.max_input_len:
-            raise ValueError(
-                f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}"
-            )
+            raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
         sampling_params.stop_sentences_to_token_ids(self.tokenizer)
-        self.add_req(prompt_ids, sampling_params, request_id)
+        self.add_req(request_id, prompt_ids, sampling_params, prompts)
         return

     def abort(self, request_id):
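A hedged usage sketch of the new async entry point, assuming an already constructed DynamicBatchManager instance named `manager` (the request id and prompt text are made up):

    from colossalai.inference.dynamic_batching.sampling_params import SamplingParams

    async def submit(manager):
        # add_input tokenizes the raw prompt string and enqueues a Req that now carries the prompt text
        await manager.add_input(request_id=0, sampling_params=SamplingParams(), prompts="hello world")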
@@ -90,8 +93,13 @@ class DynamicBatchManager:
         counter_count = 0
         # self.running_batch is not None or self.req_queue.waiting_req_list
         while True:
-            async for item in self._step():
-                yield item
+            if self.running_batch is not None or self.req_queue.waiting_req_list:
+                async for result in self._step():
+                    yield result
+            else:
+                # need to wait for new requests
+                await asyncio.sleep(0.1)
+                continue
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % self.mem_usage_interval == 0:
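The idle branch awaits asyncio.sleep rather than blocking the event loop with time.sleep. A minimal standalone sketch of that pattern (names are illustrative, not from the repo):

    import asyncio

    async def poll(queue: list):
        # async generator: yield queued items; when idle, hand control back to the event loop
        while True:
            if queue:
                yield queue.pop(0)
            else:
                await asyncio.sleep(0.1)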
@@ -103,10 +111,9 @@ class DynamicBatchManager:
             )
             self.stats_tool.print_stats()

-        if self.running_batch is None:
-            time.sleep(0.1)  # 10ms
-
-    def _set_tokenizer(self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast:bool = True,):
+    def _set_tokenizer(
+        self, tokenizer=None, tokenizer_name: str = "", trust_remote_code: bool = False, use_fast: bool = True
+    ):
         if tokenizer is not None:
             self.tokenizer = tokenizer
         else:
@@ -115,18 +122,22 @@ class DynamicBatchManager:
                     "For some LLaMA-based models, initializing the fast tokenizer may "
                     "take a long time. To eliminate the initialization time, consider "
                     f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-                    "tokenizer. This is done automatically in Colossalai.")
+                    "tokenizer. This is done automatically in Colossalai."
+                )

                 tokenizer_name = _FAST_LLAMA_TOKENIZER

             try:
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
-            except TypeError as e:
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+                )
+            except TypeError:
                 use_fast = False
-                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=use_fast,trust_remote_code=trust_remote_code)
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+                )

-    def _step(self):
+    async def _step(self):
         """
         Logic for handling requests
         """
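The try/except block mirrors the assumption that a TypeError from AutoTokenizer means no fast tokenizer is available. A standalone sketch of the same fallback (the default name reuses the _FAST_LLAMA_TOKENIZER value from the diff; everything else is illustrative):

    from transformers import AutoTokenizer

    def load_tokenizer(name="hf-internal-testing/llama-tokenizer", trust_remote_code=False):
        try:
            # prefer the fast (Rust-backed) tokenizer when the model ships one
            return AutoTokenizer.from_pretrained(name, use_fast=True, trust_remote_code=trust_remote_code)
        except TypeError:
            # fall back to the slow Python tokenizer
            return AutoTokenizer.from_pretrained(name, use_fast=False, trust_remote_code=trust_remote_code)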
@@ -136,14 +147,15 @@ class DynamicBatchManager:
             if new_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_batch)
                 self.running_batch = new_batch
-                yield from self._prefill_batch(self.running_batch)
+                async for item in self._prefill_batch(self.running_batch):
+                    yield item
                 self._filter_runing_batch()
                 self.has_wait_tokens = 0
             return

         if self.has_wait_tokens < self.max_wait_tokens:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            self._decode_batch(self.running_batch)
             self._filter_runing_batch()
             self.has_wait_tokens += 1
             return
@@ -151,7 +163,8 @@ class DynamicBatchManager:
         new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
         if new_mini_batch is not None:
             self.stats_tool.count_prompt_tokens(new_mini_batch)
-            yield from self._prefill_batch(new_mini_batch)
+            async for item in self._prefill_batch(new_mini_batch):
+                yield item
             if not new_mini_batch.is_clear():
                 self._merge_batch(self.running_batch, new_mini_batch)
                 self.running_batch.merge(new_mini_batch)
@@ -159,7 +172,8 @@ class DynamicBatchManager:

         else:
             self.stats_tool.count_output_tokens(self.running_batch)
-            yield from self._decode_batch(self.running_batch)
+            async for item in self._decode_batch(self.running_batch):
+                yield item
             self._filter_runing_batch()
             self.has_wait_tokens += 1

@@ -187,7 +201,7 @@ class DynamicBatchManager:
         )
         self.engine.cache[batch_id] = batch_data

-    def _prefill_batch(self, batch):
+    async def _prefill_batch(self, batch):
         """
         For all batches, no matter it is a new batch or a mini batch, we need to do prefill first.
         """
@@ -198,11 +212,11 @@ class DynamicBatchManager:
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
+        async for item in self._handle_finish_req(batch, has_new_finished_req):
+            yield item
         # delete finished reqs

-    def _decode_batch(self, batch: Batch):
+    async def _decode_batch(self, batch: Batch):
         """
         Decoding process
         """
@@ -210,7 +224,8 @@ class DynamicBatchManager:
         req_to_out_token_id = ans
         self._add_token_id_to_req(batch, req_to_out_token_id)
         has_new_finished_req = batch.mark_finished_req(self.eos_id)
-        yield from self._handle_finish_req(batch, has_new_finished_req)
+        async for item in self._handle_finish_req(batch, has_new_finished_req):
+            yield item

     def _filter_batch(self, batch: Batch):
         batch_id = batch.batch_id
@@ -240,15 +255,15 @@ class DynamicBatchManager:
         batch.free_self()
         del batch

-    def _handle_finish_req(self, batch: Batch, has_new_finished_req):
+    async def _handle_finish_req(self, batch: Batch, has_new_finished_req):
         if has_new_finished_req:
             finished_reqs = batch.filter_finished()
             if batch.is_clear():
                 self._remove_batch(batch)
             else:
                 self._filter_batch(batch)
-            yield from self._output_process(finished_reqs)
+            async for item in self._output_process(finished_reqs):
+                yield item

     def _filter_runing_batch(self):
         if self.running_batch is not None and self.running_batch.is_clear():
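The repeated conversions above replace `yield from` with manual delegation because `yield from` is not allowed inside async generators. A minimal standalone sketch of that pattern:

    import asyncio

    async def inner():
        for i in range(3):
            yield i

    async def outer():
        # async-generator delegation: re-yield every item produced by inner()
        async for item in inner():
            yield item

    async def main():
        async for item in outer():
            print(item)

    asyncio.run(main())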
@@ -267,7 +282,7 @@ class DynamicBatchManager:
         """
         for req in finished_reqs:
             output = self.tokenizer.decode(req.output_ids)
-            yield output, req.request_id, req.output_metadata_list
+            yield req.prompts + output

     def clean_up(self):
         # this logic should be implemented in the future.
@@ -277,7 +292,13 @@ class DynamicBatchManager:
         """
         Generate the output of a request.
         """
-        self.add_input(request_id,prompt_id,sampling_params)
+        await self.add_input(request_id, prompt_id, sampling_params)


+async def process_data(dbm):
+    async for data in dbm.loop_for_fwd():
+        print(data)


 def start_dynamic_batching(args, tp_engine, waiting_req_list):
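A hedged driver sketch showing how process_data is meant to be consumed, assuming an already constructed DynamicBatchManager instance (it mirrors the new test further down; the request id and token ids are made up):

    import asyncio

    from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
    from colossalai.inference.manager import process_data

    async def demo(manager):
        # stream generated text in the background while new requests are added
        asyncio.create_task(process_data(manager))
        await manager.add_req(4, [0, 0, 10, 10, 10], SamplingParams())
        await asyncio.sleep(5)  # give the forward loop time to run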
@@ -287,21 +308,13 @@ def start_dynamic_batching(args, tp_engine, waiting_req_list):
             max_total_token_num=args.max_total_token_num,
             batch_max_tokens=args.batch_max_tokens,
             eos_id=args.eos_id,
+            model=args.model,
             log_stats=not args.disable_log_stats,
             log_stats_interval=args.log_stats_interval,
             waiting_req_list=waiting_req_list,
         )

     except Exception:
-        batch_manager.clean_up()
-        raise
-
-    batch_manager._set_tokenizer(tokenizer_name = tp_engine.model.__class__.__name__)
-    prod_task = asyncio.create_task(batch_manager.add_input(4,sampling_params=SamplingParams(),input_ids="hello world"))
-
-    asyncio.run(prod_task)
-
-    for item in batch_manager.loop_for_fwd():
-        print(item)
-
+        raise RuntimeError("Failed to start dynamic batching")

     return batch_manager
@@ -1,33 +0,0 @@
-import asyncio
-
-shared_list = []
-
-async def producer():
-    for i in range(5):
-        await asyncio.sleep(1)  # simulate fetching data asynchronously
-        shared_list.append(i)
-        print(f"Produced {i}")
-
-async def consumer():
-    last_index = 0
-    while True:
-        await asyncio.sleep(0.5)  # add a small delay so the loop does not spin too tightly
-        if last_index < len(shared_list):
-            item = shared_list[last_index]
-            print(f"Consumed {item}")
-            yield item
-            last_index += 1
-
-async def main():
-    # create the producer and consumer tasks
-    prod_task = asyncio.create_task(producer())
-
-    # wait for the producer task to finish
-    await prod_task
-
-    async for data in consumer():
-        print(data)
-    # for the purposes of this example, only wait for a while and then stop the consumer
-    await asyncio.sleep(5)
-
-asyncio.run(main())
@@ -1,3 +1,6 @@
+import asyncio
+from dataclasses import dataclass
+
 import pytest
 import torch
 from packaging import version
@@ -5,10 +8,9 @@ from transformers import LlamaForCausalLM
 from transformers.models.llama.configuration_llama import LlamaConfig

 import colossalai
-from dataclasses import dataclass
 from colossalai.inference.dynamic_batching.io_struct import Req
 from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
-from colossalai.inference.manager import start_dynamic_batching
+from colossalai.inference.manager import process_data, start_dynamic_batching
 from colossalai.inference.tensor_parallel import TPInferEngine
 from colossalai.shardformer import ShardConfig
 from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
@@ -19,17 +21,26 @@ MAX_INPUT_LEN = 5
 MAX_OUTPUT_LEN = 16
 CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")


 @dataclass
 class args:
     max_total_token_num: int
     batch_max_tokens: int
     eos_id: int
+    model: str
     disable_log_stats: bool
     log_stats_interval: int


 def run():
-    arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10)
+    arg = args(
+        max_total_token_num=42,
+        batch_max_tokens=42,
+        eos_id=0,
+        model="llama",
+        disable_log_stats=False,
+        log_stats_interval=10,
+    )
     sampling_params = SamplingParams()

     req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
@@ -51,12 +62,14 @@ def run():

     infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
     manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
-    manager._set_tokenizer(tokenizer_name = model.__class__.__name__)
-    result_generator = manager.loop_for_fwd()
-    for result in result_generator:
-        print(result)
+    asyncio.run(test(manager))


+async def test(manager):
+    asyncio.create_task(process_data(manager))
+    await asyncio.sleep(5)
+    await manager.add_req(4, [0, 0, 10, 10, 10], SamplingParams())
+    await asyncio.sleep(5)


 def check_dynamic_forward(rank, world_size, port):