[Inference] Dynamic Batching Inference, online and offline (#4953)

* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support different lengths generating

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497outlook.com>

* [inference] Async dynamic batching  (#4894)

* finish input and output logic

* add generate

* test forward

* 1

* [inference]Re push async dynamic batching (#4901)

* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)

This reverts commit fbf3c09e67.

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Revert "[inference] Async dynamic batching  (#4894)" (#4909)

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* [infer]Add Ray Distributed Environment Init Scripts (#4911)

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* support dynamic batch for bloom model and is_running function

* [Inference]Test for new Async engine (#4935)

* infer engine

* infer engine

* test engine

* test engine

* new manager

* change step

* add

* test

* fix

* fix

* finish test

* finish test

* finish test

* finish test

* add license

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* add assertion for config (#4947)

* [Inference] Finish dynamic batching offline test (#4948)

* test

* fix test

* fix quant

* add default

* fix

* fix some bugs

* fix some bugs

* fix

* fix bug

* fix bugs

* reset param

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497outlook.com>
Jianghai 1 year ago committed by GitHub
parent 4e4a10c97d
commit cf579ff46d

@ -0,0 +1,133 @@
import asyncio
from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from .dynamic_batching.io_struct import RequestOutput
from .dynamic_batching.sampling_params import SamplingParams
class RequestTracker:
"""
A class that tracks all requests; an abstraction used by the async engine.
"""
def __init__(self) -> None:
self._requests: asyncio.Queue[str] = asyncio.Queue()
self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
self.new_requests_event = None
def __contains__(self, item):
return item in self._requests
def init_event(self):
self.new_requests_event = asyncio.Event()
def add_request(self, request_id: str):
"""Add a request to be sent to the engine on the next background
loop iteration."""
self._requests.put_nowait(request_id)
self.new_requests_event.set() # NOTE: we may find a better way to clear this event
def add_stop(self):
"""
Add a StopIteration sentinel to stop the async generator.
"""
self._finished_requests.put_nowait(StopIteration)
self.new_requests_event.clear()
def process_request_output(self, request_output: RequestOutput) -> None:
"""Process a request output from the engine."""
self._finished_requests.put_nowait(request_output)
async def wait_for_new_requests(self):
await self.new_requests_event.wait()
def __aiter__(self):
return self
async def __anext__(self) -> RequestOutput:
result = await self._finished_requests.get()
# print("result of ", result)
if result is StopIteration:
raise StopAsyncIteration
return result
class Async_Engine:
"""
Use an engine to launch the Ray Driver --> Ray Worker --> Async_Manager chain.
Background loop: runs inference on the requests in the waiting list (listens for new ones).
Request tracker: manages incoming requests and stores finished ones.
Generate: the exposed function for adding a new input and returning finished outputs.
"""
def __init__(
self,
router_config,
engine_config,
start_engine_loop: bool = True,
) -> None:
self.driver = Driver(router_config=router_config, engine_config=engine_config)
self.background_loop = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
def _step(self):
"""
Logic for handling requests
"""
request_outputs = self.driver.step()
if request_outputs is not None:
for request_output in request_outputs:
self._request_tracker.process_request_output(request_output)
self._request_tracker.add_stop()
def abort_request(self, request_id: str):
self.driver.abort(request_id)
def _has_requests_in_progress(self):
return self.driver.is_running()
async def run_loop_fwd(self):
has_requests_in_progress = self._has_requests_in_progress()
while True:
if not has_requests_in_progress:
await self._request_tracker.wait_for_new_requests()
self._step()
await asyncio.sleep(0)
@property
def is_running(self):
return self.background_loop is not None and not self.background_loop.done()
def start_background_loop(self):
if self.is_running:
raise RuntimeError("Background loop is already running.")
self._request_tracker.init_event()
self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
self.background_loop = asyncio.shield(self.background_loop_unshielded)
async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
self.driver.add_input(request_id, prompt, sampling_params)
self._request_tracker.add_request(request_id)
async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
"""
The only exposed function: add a new request and return an async generator that yields the finished results.
"""
try:
if not self.is_running:
self.start_background_loop()
await self.add_request(request_id, prompt, sampling_params)
async for request_output in self._request_tracker:
yield request_output
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the request.
self.abort_request(request_id)
raise e
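
Putting the pieces together, generate is the single entry point: it lazily starts the background loop, hands the request to the Ray Driver, and streams back finished RequestOutput objects. A minimal usage sketch, assuming Async_Engine is importable from this module, that Ray workers with GPUs are available, and that "llama-7b-path" and the request id are placeholders:

import asyncio

from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams


async def main():
    # Placeholder model path and request id; a Ray cluster with GPUs is required.
    engine_config = EngineArgsClass(model="llama-7b-path", tensor_parallel_size=1)
    router_config = RooterArgsClass(model="llama-7b-path", max_total_token_num=640, batch_max_tokens=640)
    engine = Async_Engine(router_config=router_config, engine_config=engine_config)

    sampling_params = SamplingParams(max_new_tokens=32)
    # generate() lazily starts the background loop and yields finished RequestOutput objects
    async for request_output in engine.generate("req-0", "Introduce some landmarks in Beijing", sampling_params):
        print(request_output)


asyncio.run(main())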

@ -0,0 +1,151 @@
from typing import List
from .dynamic_batching.io_struct import Batch, Req, RequestOutput
from .manager import DynamicBatchManager
from .tensor_parallel import TPInferEngine
class Async_DynamicBatchManager(DynamicBatchManager):
def __init__(
self,
tp_engine: TPInferEngine,
max_total_token_num: int,
batch_max_tokens: int,
model: str,
tokenizer=None,
eos_id=None,
log_stats=True,
log_stats_interval=10,
running_batch: Batch = None,
waiting_req_list: List = [],
):
"""
Args:
    tp_engine : the TP engine held by the dynamic batch manager, created before the manager
    max_total_token_num : token budget of the memory manager, defaults to max batch size * (max input len + max output len)
    batch_max_tokens : max tokens of one batch, defaults to (max input len + max output len) * num_requests
    running_max_req_size : max number of requests in the running batch, equal to MAX_BATCH_SIZE of the tp engine
    eos_id : the end-of-sequence token id
    model : the model weight dir path; config, weights and tokenizer are loaded from this dir
    log_stats : whether to log stats
    log_stats_interval : interval between stats logs
    running_batch : the currently running batch
    waiting_req_list : list of waiting requests, initialized before the dynamic batch manager
"""
super().__init__(
tp_engine,
max_total_token_num,
batch_max_tokens,
model,
tokenizer,
eos_id,
log_stats,
log_stats_interval,
running_batch,
waiting_req_list,
)
def _step(self):
"""
Logic for handling requests
"""
has_new_finished = False
if self.running_batch is None:
new_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
has_new_finished, outputs = self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens = 0
else:
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1
else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
has_new_finished, outputs = self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0
else:
self.stats_tool.count_output_tokens(self.running_batch)
has_new_finished, outputs = self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1
if has_new_finished:
return outputs
return None
def _prefill_batch(self, batch):
"""
For every batch, whether it is a new batch or a mini batch, prefill must be done first.
"""
self._init_batch(batch)
# TODO: figure out whether the cache and batch id are needed
ans = self.engine._prefill_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
outputs = self._handle_finish_req(batch, has_new_finished_req)
return has_new_finished_req, outputs
# delete finished reqs
def _decode_batch(self, batch: Batch):
"""
Decoding process
"""
ans = self.engine._decode_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
outputs = self._handle_finish_req(batch, has_new_finished_req)
return has_new_finished_req, outputs
def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
finished_reqs = batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
return self._output_process(finished_reqs)
return None
def _output_process(self, finished_reqs: List[Req]):
"""
Process the output of a batch.
"""
outputs = []
for req in finished_reqs:
output = self.tokenizer.decode(req.output_ids)
outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
return outputs
def start_dynamic_batching(args, tp_engine, waiting_req_list):
try:
batch_manager = Async_DynamicBatchManager(
tp_engine=tp_engine,
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
model=args.model,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)
except Exception:
raise
return batch_manager

@ -0,0 +1,40 @@
"""
Motivated by vLLM (https://github.com/vllm-project/vllm), this module handles tokenizer loading.
License: MIT, see LICENSE for more details.
"""
from transformers import AutoTokenizer
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
def get_tokenizer(
tokenizer=None,
tokenizer_name: str = "",
trust_remote_code: bool = False,
use_fast: bool = True,
):
if tokenizer is not None:
tokenizer = tokenizer
else:
if "llama" in tokenizer_name.lower() and use_fast == True:
print(
"For some LLaMA-based models, initializing the fast tokenizer may "
"take a long time. To eliminate the initialization time, consider "
f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
"tokenizer. This is done automatically in Colossalai."
)
tokenizer_name = _FAST_LLAMA_TOKENIZER
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
)
except TypeError:
use_fast = False
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
)
return tokenizer
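
A short usage sketch of get_tokenizer; the tokenizer name is the fast LLaMA tokenizer already referenced above, and the prompt is only an example:

# Sketch: load a tokenizer by name and round-trip a prompt; for LLaMA-family
# names with use_fast=True the fast llama-tokenizer is substituted automatically.
tokenizer = get_tokenizer(tokenizer_name=_FAST_LLAMA_TOKENIZER, use_fast=True)
token_ids = tokenizer.encode("Introduce some landmarks in Beijing")
print(token_ids)
print(tokenizer.decode(token_ids))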

@ -0,0 +1,346 @@
# Adapted from https://github.com/ModelTC/lightllm
import collections
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
import torch
from colossalai.inference.tensor_parallel import MemoryManager
# make batch infer state an attr of InferBatch
class InferSamplingParams:
def __init__(
self,
do_sample: bool = False,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
vocab_size: int = -1,
) -> None:
self.do_sample = do_sample
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
if self.top_k == -1:
self.top_k = vocab_size
return
@dataclass
class InferBatch:
batch_id: int
requests: List
requests_idx_mapping: Dict[int, int]
input_ids: torch.Tensor
all_input_ids: List[List[int]]
input_lengths: List[int]
out_token_id_counts: List
sampling_param_list: List[InferSamplingParams]
nopad_total_token_num: int
nopad_max_len_in_batch: int
nopad_b_loc: torch.Tensor
nopad_b_start_loc: torch.Tensor
nopad_b_seq_len: torch.Tensor
cache_manager: MemoryManager
max_total_len: int
@classmethod
@torch.no_grad()
def init_batch(
cls,
batch_id,
requests,
dtype: torch.dtype,
device: torch.device,
cache_manager: MemoryManager,
vocab_size: int,
max_total_len: int,
) -> "InferBatch":
input_lengths = []
all_input_ids = []
requests_idx_mapping = {}
out_token_id_counts = []
sampling_param_list = []
nopad_total_token_num = 0
nopad_max_len_in_batch = 0
nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda")
# to avoid memory issues, we pre-allocate 12 extra slots per request.
nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda")
for i, r in enumerate(requests):
# request id -> idx in list mapping
requests_idx_mapping[r["request_id"]] = i
tokenized_input = r["input_id"]
input_length = len(tokenized_input)
input_lengths.append(input_length)
all_input_ids.append(tokenized_input)
out_token_id_counts.append(collections.defaultdict(int))
# postprocessor
sampling_param = r["sampling_param"]
sampling_param["vocab_size"] = vocab_size
sampling_param_list.append(InferSamplingParams(**sampling_param))
nopad_total_token_num += input_length
nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length)
nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda")
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
if len(requests) > 1:
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
else:
input_ids = all_input_ids[0]
# Create tensors on device
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
return cls(
batch_id=batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
nopad_total_token_num=nopad_total_token_num,
nopad_max_len_in_batch=nopad_max_len_in_batch,
nopad_b_loc=nopad_b_loc,
nopad_b_start_loc=nopad_b_start_loc,
nopad_b_seq_len=nopad_b_seq_len,
out_token_id_counts=out_token_id_counts,
sampling_param_list=sampling_param_list,
cache_manager=cache_manager,
max_total_len=max_total_len,
)
@torch.no_grad()
def free_self(self) -> None:
"""
Free the memory of the InferBatch itself
"""
remove_index = []
for idx in range(len(self)):
remove_index.append(
self.nopad_b_loc[
idx,
(self.nopad_max_len_in_batch - 1)
- (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
]
)
remove_index = torch.cat(remove_index, dim=-1)
self.cache_manager.free(remove_index)
@torch.no_grad()
def filter(self, request_ids: List[int]) -> "InferBatch":
"""
Filter out finished requests and return a new InferBatch with the remaining ones.
"""
if len(request_ids) == 0:
raise ValueError("Batch must have at least one request")
if len(request_ids) == len(self):
return self
requests_idx_mapping = {}
indices = []
requests = []
all_input_ids = []
input_lengths = []
nopad_total_token_num = 0
nopad_max_len_in_batch = 0
nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda")
nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
left_idx = []
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
left_idx.append(idx)
left_idx_set = set(left_idx)
remove_index = []
for idx in range(len(self)):
if idx not in left_idx_set:
remove_index.append(
self.nopad_b_loc[
idx,
(self.nopad_max_len_in_batch - 1)
- (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
]
)
remove_index = torch.cat(remove_index, dim=-1)
self.cache_manager.free(remove_index)
nopad_max_len_in_batch = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
nopad_b_seq_len[:] = self.nopad_b_seq_len[indices]
nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item()
nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
nopad_total_token_num = torch.sum(nopad_b_seq_len).item()
nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[
indices,
(self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1),
]
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
requests_idx_mapping[request_id] = i
requests.append(self.requests[idx])
all_input_ids.append(self.all_input_ids[idx])
input_lengths.append(self.input_lengths[idx])
input_ids = self.input_ids[indices]
return InferBatch(
batch_id=self.batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
nopad_total_token_num=nopad_total_token_num,
nopad_max_len_in_batch=nopad_max_len_in_batch,
nopad_b_loc=nopad_b_loc,
nopad_b_start_loc=nopad_b_start_loc,
nopad_b_seq_len=nopad_b_seq_len,
out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices],
sampling_param_list=[self.sampling_param_list[_i] for _i in indices],
cache_manager=self.cache_manager,
max_total_len=self.max_total_len,
)
@classmethod
@torch.no_grad()
def merge(cls, batch1, batch2) -> "InferBatch":
"""
Return a merged InferBatch.
"""
requests = batch1.requests + batch2.requests
requests_idx_mapping = {}
new_batch_size = len(batch1) + len(batch2)
input_ids = batch1.input_ids.new_empty(new_batch_size)
all_input_ids = []
input_lengths = []
out_token_id_counts = []
sampling_param_list = []
cumulative_batch_size = 0
nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num
nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch)
max_total_len = max(batch1.max_total_len, batch2.max_total_len)
nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda")
nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
nopad_start_loc_len_temp = 0
batches = [batch1, batch2]
for i, batch in enumerate(batches):
if i == 0:
requests_idx_mapping = batch.requests_idx_mapping
else:
for k, v in batch.requests_idx_mapping.items():
requests_idx_mapping[k] = v + cumulative_batch_size
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
input_ids[start_index:end_index] = batch.input_ids
nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len
nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp
nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1]
nopad_b_loc[
start_index:end_index,
nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1,
] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1]
all_input_ids.extend(batch.all_input_ids)
input_lengths.extend(batch.input_lengths)
out_token_id_counts.extend(batch.out_token_id_counts)
sampling_param_list.extend(batch.sampling_param_list)
# Update
cumulative_batch_size += len(batch)
nopad_b_loc[:, nopad_max_len_in_batch - 1] = (
nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda")
)
return InferBatch(
batch_id=batches[0].batch_id,
requests=requests,
requests_idx_mapping=requests_idx_mapping,
input_ids=input_ids,
input_lengths=input_lengths,
all_input_ids=all_input_ids,
nopad_total_token_num=nopad_total_token_num,
nopad_max_len_in_batch=nopad_max_len_in_batch,
nopad_b_loc=nopad_b_loc,
nopad_b_start_loc=nopad_b_start_loc,
nopad_b_seq_len=nopad_b_seq_len,
out_token_id_counts=out_token_id_counts,
sampling_param_list=sampling_param_list,
cache_manager=batches[0].cache_manager,
max_total_len=max_total_len,
)
def __len__(self):
return len(self.requests)
def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
temperatures: List[float] = []
top_ps: List[float] = []
top_ks: List[int] = []
p_token_ids: List[int] = []
p_token_counts: List[int] = []
p_seq_len: List[int] = [
0,
]
p_max_len_in_batch: int = 0
for i, id_to_count in enumerate(self.out_token_id_counts):
sample_param = self.sampling_param_list[i]
presence_penalties.append(sample_param.presence_penalty)
frequency_penalties.append(sample_param.frequency_penalty)
temperatures.append(sample_param.temperature)
top_ps.append(sample_param.top_p)
top_ks.append(sample_param.top_k)
for token_id, count in id_to_count.items():
p_token_ids.append(token_id)
p_token_counts.append(count)
p_seq_len.append(len(id_to_count))
p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count))
presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda")
frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda")
temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda")
top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda")
top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda")
p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda")
p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda")
p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
return (
presence_penalties,
frequency_penalties,
temperatures,
top_ps,
top_ks,
p_token_ids,
p_token_counts,
p_cumsum_seq_len,
p_max_len_in_batch,
)
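
For reference, init_batch consumes plain request dicts with the keys produced by Req.to_rpc_obj; the sampling_param sub-dict is expanded into InferSamplingParams after the vocab size is injected. A schematic payload (values are illustrative; a real call additionally needs a CUDA device, a MemoryManager and max_total_len):

# Illustrative request payload for InferBatch.init_batch; the token ids and
# sampling values are made up. "output_len" comes from Req.to_rpc_obj but is
# not read by init_batch itself.
request = {
    "request_id": "req-0",
    "input_id": [1, 15043, 29892],  # tokenized prompt
    "output_len": 32,               # max new tokens for this request
    "sampling_param": {             # produced by SamplingParams.to_dict()
        "do_sample": False,
        "presence_penalty": 0.0,
        "frequency_penalty": 0.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": -1,
    },
}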

@ -0,0 +1,166 @@
# Adapted from https://github.com/ModelTC/lightllm
from typing import Dict, List, Tuple
from .sampling_params import SamplingParams
class Req:
def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""):
self.request_id = request_id
self.prompt_ids = prompt_ids
self.input_len = len(prompt_ids)
self.max_output_len = sample_params.max_new_tokens
self.sample_params = sample_params
self.output_ids = []
self.output_metadata_list = []
self.has_generate_finished = False
self.aborted = False
self.prompts = prompts
def to_rpc_obj(self):
return {
"request_id": self.request_id,
"input_id": self.prompt_ids,
"output_len": self.max_output_len,
"sampling_param": self.sample_params.to_dict(),
}
def stop_sequences_matched(self):
# should we add stop sequences to the sample params?
if self.sample_params.stop_sequences is not None:
for stop_token_ids in self.sample_params.stop_sequences:
stop_len = len(stop_token_ids)
if (
stop_len > 0
and len(self.output_ids) >= stop_len
and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
):
return True
return False
def __repr__(self):
return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, "
class Batch:
def __init__(self, batch_id, reqs: List[Req]):
self.batch_id = batch_id
self.reqs = reqs
self.id_to_reqs = {req.request_id: req for req in reqs}
def input_tokens(self):
batch_input_tokens = 0
for req in self.reqs:
batch_input_tokens += req.input_len
return batch_input_tokens
def calcu_max_tokens(self):
tokens = 0
for req in self.reqs:
tokens += req.input_len + req.max_output_len
return tokens
def calcu_used_tokens(self):
tokens = 0
for req in self.reqs:
tokens += req.input_len + len(req.output_ids)
return tokens
def mark_finished_req(self, eos_id, engine_max_output_len):
has_new_finish = False
for req in self.reqs:
if req.stop_sequences_matched():
req.has_generate_finished = True
has_new_finish = True
if len(req.output_ids) >= engine_max_output_len:
req.has_generate_finished = True
has_new_finish = True
if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False:
req.has_generate_finished = True
has_new_finish = True
if len(req.output_ids) >= req.max_output_len or req.aborted:
req.has_generate_finished = True
has_new_finish = True
return has_new_finish
def filter_finished(self) -> List[Req]:
"""
Filter finished requests out of the batch; the finished ones are removed from 'reqs' and returned.
"""
# TODO: the logic of return should be defined here.
unfinished_req = []
finished_req = []
for req in self.reqs:
if not req.has_generate_finished:
unfinished_req.append(req)
else:
finished_req.append(req)
self.reqs = unfinished_req
self.id_to_reqs = {req.request_id: req for req in self.reqs}
return finished_req
def is_clear(self):
return len(self.reqs) == 0
def merge(self, mini_batch):
for _req in mini_batch.reqs:
self.reqs.append(_req)
self.id_to_reqs = {req.request_id: req for req in self.reqs}
return
def __repr__(self):
return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, "
def __len__(self):
return len(self.reqs)
class BatchTokenIdOut:
def __init__(self):
self.reqs_infs: List[
Tuple[str, int, Dict, bool, bool]
] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
class BatchStrOut:
def __init__(self):
self.reqs_infs: List[
Tuple[str, str, Dict, bool, bool]
] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state]
class AbortReq:
def __init__(self, req_id):
self.req_id = req_id
class RequestOutput:
"""The output data of a request to the LLM.
Args:
request_id: The unique ID of the request.
prompt: The prompt string of the request.
prompt_token_ids: The token IDs of the prompt.
outputs: The output sequences of the request.
"""
def __init__(
self,
request_id: str,
prompt: str,
prompt_token_ids: List[int],
outputs,
) -> None:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.outputs = outputs
def __repr__(self) -> str:
return (
f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"outputs={self.outputs}, "
)

@ -0,0 +1,152 @@
import logging
import os
from typing import List
import ray
import ray.util.collective as collective
import torch
from transformers import AutoModelForCausalLM
import colossalai
from colossalai.inference.async_manager import start_dynamic_batching
from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer
from colossalai.inference.dynamic_batching.io_struct import RequestOutput
from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import free_port
ray_serve_logger = logging.getLogger("ray.serve")
def log_cuda_info(scope_name: str):
ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}")
ray_serve_logger.info(
f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}"
)
if torch.cuda.is_available():
ray_serve_logger.info(
f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}"
)
else:
ray_serve_logger.info(f" {scope_name}: cuda is not available!")
@ray.remote(num_gpus=1)
class Worker:
def __init__(
self,
model_path: str,
tensor_parallel_size: int,
max_batch_size: int,
max_input_len: int,
max_output_len: int,
router_config: RooterArgsClass,
):
log_cuda_info("Worker.init")
self.tensor_parallel_size = tensor_parallel_size
self.model_path = model_path
self.max_batch_size = max_batch_size
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.router_config = router_config
def setup(self, world_size, rank, port):
# initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
collective.init_collective_group(world_size, rank, "nccl", "default")
# initialize and set distributed environment
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
log_cuda_info("Worker.setup")
# Load model
self.tokenizer = get_tokenizer(tokenizer_name=self.model_path)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
)
shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)
self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, [])
return True
# def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> List[str]:
# ray_serve_logger.info(f"text: {prompt}")
# final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id)
# return final_outputs
def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
self.start_dynamic_batching.add_input(request_id, prompt, sampling_params)
def abort(self, request_id: str):
self.start_dynamic_batching.abort(request_id)
def step(self) -> List[RequestOutput]:
return self.start_dynamic_batching._step()
def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt)
def is_running(self):
return self.start_dynamic_batching.is_running()
class Driver:
def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass):
log_cuda_info("Driver:init")
model_path = engine_config.model
tensor_parallel_size = engine_config.tensor_parallel_size
self.num_workers = tensor_parallel_size
self.workers = []
init_rets = []
# Just grab a free port on localhost
# NOTE workers in this communication group listen to the same port
available_port = free_port()
for i in range(self.num_workers):
worker_name = "worker_idx_{}".format(i)
w = Worker.options(name=worker_name).remote(
model_path,
self.num_workers,
engine_config.max_batch_size,
engine_config.max_input_len,
engine_config.max_output_len,
router_config,
)
self.workers.append(w)
init_rets.append(w.setup.remote(self.num_workers, i, available_port))
_options = {
"group_name": "default_driver",
"world_size": self.num_workers,
"ranks": [i for i in range(self.num_workers)],
"backend": "nccl",
}
collective.create_collective_group(self.workers, **_options)
_ = ray.get(init_rets)
def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers])
def abort(self, request_id: str):
ray.get([w.abort.remote(request_id) for w in self.workers])
def step(self):
results = ray.get([w.step.remote() for w in self.workers])
outputs = results[0] # get any one of the copies
return outputs
def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str):
ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])
def is_running(self):
results = ray.get([w.is_running.remote() for w in self.workers])
return any(results)

@ -0,0 +1,58 @@
import logging
import yaml
from pydantic import BaseModel
logger = logging.getLogger(__name__)
class EngineArgsClass(BaseModel):
"""Config for Engine"""
model: str
tensor_parallel_size: int = 2
max_batch_size: int = 4
max_input_len: int = 128
max_output_len: int = 32
class RooterArgsClass(BaseModel):
"""Config for Rooter"""
max_total_token_num: int = 42
batch_max_tokens: int = 42
eos_id: int = 0
disable_log_stats: bool = False
log_stats_interval: int = 10
model: str
class RayInitConfig(BaseModel):
"""All-together configs without app router config"""
engine_config_data: EngineArgsClass
router_config_data: RooterArgsClass
@classmethod
def from_yaml_path(cls, path: str):
try:
with open(path, "r") as yaml_file:
try:
config = yaml.safe_load(yaml_file)
# serve deployment config
engine_config = config.get("engine_config", {})
router_config = config.get("router_config", {})
return cls(
engine_config_data=engine_config,
router_config_data=router_config,
)
except yaml.YAMLError as e:
logger.error(f"An Error occurred when parsing yaml: {e}")
raise
except FileNotFoundError:
logger.error(f"The file '{path}' does not exist!")
raise
except OSError as e:
logger.error(f"An Error occurred: {e}")
raise
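
The loader above expects a YAML file with top-level engine_config and router_config sections matching EngineArgsClass and RooterArgsClass. A sketch that writes such a file and loads it; the model path is a placeholder:

# Sketch: write a config YAML with the engine_config / router_config sections
# parsed above, then load it. "llama-7b-path" is a placeholder model directory.
import tempfile

_SAMPLE_YAML = """
engine_config:
  model: llama-7b-path
  tensor_parallel_size: 1
  max_batch_size: 4
  max_input_len: 128
  max_output_len: 32
router_config:
  model: llama-7b-path
  max_total_token_num: 640
  batch_max_tokens: 640
  eos_id: 2
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(_SAMPLE_YAML)
    config_path = f.name

config = RayInitConfig.from_yaml_path(config_path)
print(config.engine_config_data.model, config.router_config_data.batch_max_tokens)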

@ -0,0 +1,73 @@
# Adapted from https://github.com/ModelTC/lightllm
import uuid
from typing import List
import numpy as np
from .io_struct import Batch, Req
class ReqQueue:
def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None:
self.max_total_tokens = max_total_tokens
assert batch_max_tokens is not None
self.batch_max_tokens = batch_max_tokens
self.running_max_req_size = running_max_req_size
self.waiting_req_list: List[Req] = waiting_req_list
def append(self, req):
self.waiting_req_list.append(req)
return
def _init_cache_list(self, current_batch: Batch):
if current_batch is not None:
self.cache_len_list = [
(req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1)
for req in current_batch.reqs
]
else:
self.cache_len_list = []
# @calculate_time(show=True, min_cost_ms=0.1)
def _can_add_new_req(self, req):
self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1))  # hard to analyze exactly
self.cache_len_list.sort(key=lambda x: -x[1])
left_out_len_array = np.array([e[1] for e in self.cache_len_list])
# assert left_out_len_array.min() >= 0
has_run_len_array = np.array([e[0] for e in self.cache_len_list])
cum_run_len_array = np.cumsum(has_run_len_array)
size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
# NOTE: change here < to <=
return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size
def generate_new_batch(self, current_batch: Batch = None):
if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size:
return None
self._init_cache_list(current_batch)
can_run_list = []
new_batch_total_tokens = 0
aborted_count = 0
for req in self.waiting_req_list:
flag = self._can_add_new_req(req)
if req.aborted:
aborted_count += 1
continue
if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens:
can_run_list.append(req)
new_batch_total_tokens += req.input_len
else:
break
if len(can_run_list) != 0:
new_batch = Batch(uuid.uuid4().hex, can_run_list)
self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
return new_batch
else:
return None
def __len__(self):
return self.waiting_req_list.__len__()
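
The admission check in _can_add_new_req estimates the peak token usage of the batch: requests are sorted by remaining output length (descending), and for each prefix the demand is remaining_len * prefix_size + cumulative already-used tokens. A standalone sketch of that computation with made-up lengths:

# Standalone sketch of the token-budget estimate used in _can_add_new_req;
# the (already_used, remaining_output) pairs are made up.
import numpy as np

cache_len_list = [(17, 31), (9, 63), (33, 15)]  # (tokens already occupied, tokens still to generate)
cache_len_list.sort(key=lambda x: -x[1])        # longest remaining output first

left_out_len_array = np.array([e[1] for e in cache_len_list])
has_run_len_array = np.array([e[0] for e in cache_len_list])
cum_run_len_array = np.cumsum(has_run_len_array)
size_array = np.arange(1, len(cache_len_list) + 1, 1)

# demand evaluated at the moment each request would finish, worst case over the batch
need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
print(need_max_token_num)  # admit the new request only if this fits into max_total_tokens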

@ -0,0 +1,83 @@
# Adapted from https://github.com/ModelTC/lightllm
"""Sampling parameters for text generation."""
from typing import List, Optional, Union
_SAMPLING_EPS = 1e-5
class SamplingParams:
def __init__(
self,
do_sample: bool = False,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1, # -1 is for all
ignore_eos: bool = False,
max_new_tokens: int = 256,
stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation
) -> None:
self.do_sample = do_sample
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.ignore_eos = ignore_eos
self.max_new_tokens = max_new_tokens
self.stop_sequences = stop_sequences
if self.do_sample == False:
self.temperature = 1.0
self.top_p = 1.0
self.top_k = 1
if (
self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS
):  # temperature is too low, fall back to greedy search
self.temperature = 1.0
self.top_k = 1
return
def verify(self):
if self.presence_penalty < 0.0:
raise ValueError(f"presence_penalty must be >= 0.0, got {self.presence_penalty}")
if self.frequency_penalty < 0.0:
raise ValueError(f"frequency_penalty must be >= 0.0, got {self.frequency_penalty}")
if self.temperature <= 0.0:
raise ValueError(f"temperature must be > 0.0, got {self.temperature}")
if self.top_p <= 0.0 or self.top_p > 1.0:
raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable) or at least 1, got {self.top_k}.")
if self.max_new_tokens < 1:
raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
return
def stop_sentences_to_token_ids(self, tokenizer):
if self.stop_sequences is None:
self.stop_sequences = []
else:
if isinstance(self.stop_sequences, str):
self.stop_sequences = [self.stop_sequences]
new_stop_sequences = []
for stop_str in self.stop_sequences:
stop_str_ids = tokenizer.encode(stop_str)
if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id
stop_str_ids = stop_str_ids[1:]
if len(stop_str_ids) > 0:
new_stop_sequences.append(stop_str_ids)
self.stop_sequences = new_stop_sequences
return
def to_dict(self):
ret = {}
ret["do_sample"] = self.do_sample
ret["presence_penalty"] = self.presence_penalty
ret["frequency_penalty"] = self.frequency_penalty
ret["temperature"] = self.temperature
ret["top_p"] = self.top_p
ret["top_k"] = self.top_k
# if self.ignore_eos is not None:
# ret["ignore_eos"] = self.ignore_eos
return ret
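
A short sketch of constructing and serializing SamplingParams; the values are only illustrative:

# Sketch: build sampling parameters, validate them, and serialize them the way
# Req.to_rpc_obj does before shipping them to the workers.
params = SamplingParams(do_sample=True, temperature=0.8, top_p=0.95, top_k=50, max_new_tokens=64)
params.verify()          # raises ValueError for out-of-range values
print(params.to_dict())  # plain dict sent to the workers
# Stop strings are tokenized once a tokenizer is available:
# params.stop_sequences = "###"; params.stop_sentences_to_token_ids(tokenizer)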

@ -0,0 +1,45 @@
# Adapted from https://github.com/ModelTC/lightllm
import time
class Stats:
def __init__(self, log_status, log_stats_interval) -> None:
self.log_stats = log_status
self.log_stats_interval = log_stats_interval
self.last_log_time = time.time()
self.all_tokens = 0
self.output_tokens = 0
self.prompt_tokens = 0
return
def count_prompt_tokens(self, run_batch):
if self.log_stats:
tokens = run_batch.input_tokens()
self.prompt_tokens += tokens
self.all_tokens += tokens
return
def count_output_tokens(self, run_batch):
if self.log_stats:
tokens = len(run_batch.reqs)
self.output_tokens += tokens
self.all_tokens += tokens
return
def print_stats(self):
if not self.log_stats:
return
now = time.time()
if now - self.last_log_time > self.log_stats_interval:
print(
f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
)
self.all_tokens = 0
self.output_tokens = 0
self.prompt_tokens = 0
self.last_log_time = now
return

@ -0,0 +1,296 @@
# Adapted from https://github.com/ModelTC/lightllm
import time
from typing import List
from .dynamic_batching.get_tokenizer import get_tokenizer
from .dynamic_batching.infer_batch import InferBatch
from .dynamic_batching.io_struct import Batch, Req
from .dynamic_batching.req_queue import ReqQueue
from .dynamic_batching.sampling_params import SamplingParams
from .dynamic_batching.stats import Stats
from .tensor_parallel import TPInferEngine
class DynamicBatchManager:
def __init__(
self,
tp_engine: TPInferEngine,
max_total_token_num,
batch_max_tokens,
model,
tokenizer=None,
eos_id=None,
log_stats=True,
log_stats_interval=10,
running_batch: Batch = None,
waiting_req_list: List = [],
):
"""
Args:
    tp_engine : the TP engine held by the dynamic batch manager, created before the manager
    max_total_token_num : token budget of the memory manager, defaults to max batch size * (max input len + max output len)
    batch_max_tokens : max tokens of one batch, defaults to (max input len + max output len) * num_requests
    running_max_req_size : max number of requests in the running batch, equal to MAX_BATCH_SIZE of the tp engine
    eos_id : the end-of-sequence token id
    model : the model weight dir path; config, weights and tokenizer are loaded from this dir
    log_stats : whether to log stats
    log_stats_interval : interval between stats logs
    running_batch : the currently running batch
    waiting_req_list : list of waiting requests, initialized before the dynamic batch manager
"""
self.engine = tp_engine
self.max_total_token_num = max_total_token_num
running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2
self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list)
# all the inputs should be put into req_queue: waiting req list
assert max_total_token_num >= self.engine.max_batch_size * (
self.engine.max_input_len + self.engine.max_output_len
), "max_total_token_num should be greater than max_batch_size * (max_input_len+max_output_len)"
assert (
batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len
), "batch_max_tokens should be greater than (max_input_len+max_output_len)"
self.running_batch: Batch = running_batch
self.eos_id = eos_id
self.has_wait_tokens = 0
self.max_wait_tokens = 10
self.model = model
self.stats_tool = Stats(log_stats, log_stats_interval)
self.mem_usage_interval = log_stats_interval * 2
self.tokenizer = get_tokenizer(tokenizer_name=self.model) if tokenizer is None else tokenizer
if self.eos_id == None:
self.eos_id = self.tokenizer.eos_token_id
def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
"""
Add a new request to the request queue; during initialization all requests are held in the waiting list.
"""
sampling_params.max_new_tokens = (
self.engine.max_output_len
if sampling_params.max_new_tokens > self.engine.max_output_len
else sampling_params.max_new_tokens
)
req = Req(request_id, prompt_ids, sampling_params, prompts)
self.req_queue.append(req)
return
def add_input(self, request_id, prompts, sampling_params):
"""
Encode and add a new input to the request queue. Only single-sequence input is supported for now.
"""
prompt_ids = self.tokenizer.encode(prompts)
prompt_len = len(prompt_ids)
if prompt_len > self.engine.max_input_len:
raise ValueError(f"the input prompt token len {prompt_len} is too long > {self.engine.max_input_len}")
sampling_params.stop_sentences_to_token_ids(self.tokenizer)
self.add_req(request_id, prompt_ids, sampling_params, prompts)
return
def abort(self, request_id):
if self.running_batch is not None:
for req in self.running_batch.reqs:
if req.request_id == request_id:
req.has_generate_finished = True
req.aborted = True
for req in self.req_queue.waiting_req_list:
if req.request_id == request_id:
req.has_generate_finished = True
req.aborted = True
return
def loop_for_fwd(self):
"""
The main loop for a dynamic batching process.
"""
counter_count = 0
# self.running_batch is not None or self.req_queue.waiting_req_list
while self.running_batch is not None or self.req_queue.waiting_req_list:
yield from self._step()
counter_count += 1
if self.running_batch is not None:
if counter_count % self.mem_usage_interval == 0:
print(
"current batch size:",
len(self.running_batch.reqs),
"token used ratio:",
self.running_batch.calcu_used_tokens() / self.max_total_token_num,
)
self.stats_tool.print_stats()
if self.running_batch is None:
time.sleep(0.1)  # 100 ms
def _step(self):
"""
Logic for handling requests
"""
if self.running_batch is None:
new_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_batch is not None:
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
yield from self._prefill_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens = 0
return
if self.has_wait_tokens < self.max_wait_tokens:
self.stats_tool.count_output_tokens(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1
return
else:
new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
if new_mini_batch is not None:
self.stats_tool.count_prompt_tokens(new_mini_batch)
yield from self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self._merge_batch(self.running_batch, new_mini_batch)
self.running_batch.merge(new_mini_batch)
self.has_wait_tokens = 0
else:
self.stats_tool.count_output_tokens(self.running_batch)
yield from self._decode_batch(self.running_batch)
self._filter_runing_batch()
self.has_wait_tokens += 1
return
def _init_batch(self, batch: Batch, dtype="fp16"):
reqs = [r.to_rpc_obj() for r in batch.reqs]
batch_id = batch.batch_id
import torch
if dtype == "fp16":
dtype = torch.float16
else:
assert False, "error dtype"
batch_data = InferBatch.init_batch(
batch_id,
reqs,
dtype,
torch.cuda.current_device(),
self.engine.cache_manager,
self.engine.model.config.vocab_size,
self.engine.max_input_len + self.engine.max_output_len,
)
self.engine.cache[batch_id] = batch_data
def _prefill_batch(self, batch):
"""
For every batch, whether it is a new batch or a mini batch, prefill must be done first.
"""
self._init_batch(batch)
# TODO: figure out whether the cache and batch id are needed
ans = self.engine._prefill_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
yield from self._handle_finish_req(batch, has_new_finished_req)
# delete finished reqs
def _decode_batch(self, batch: Batch):
"""
Decoding process
"""
ans = self.engine._decode_batch(batch.batch_id)
req_to_out_token_id = ans
self._add_token_id_to_req(batch, req_to_out_token_id)
has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
yield from self._handle_finish_req(batch, has_new_finished_req)
def _filter_batch(self, batch: Batch):
batch_id = batch.batch_id
req_id_list = [r.request_id for r in batch.reqs]
batch = self.engine.cache.pop(batch_id)
filter_batch = batch.filter(req_id_list)
del batch
self.engine.cache[batch_id] = filter_batch
def _merge_batch(self, batch1, batch2):
"""
Merge new mini batch into running batch.
"""
batch1 = self.engine.cache.pop(batch1.batch_id)
batch2 = self.engine.cache.pop(batch2.batch_id)
m_batch = InferBatch.merge(batch1, batch2)
self.engine.cache[batch1.batch_id] = m_batch
del batch1
del batch2
def _remove_batch(self, batch):
"""
Remove finished batch.
"""
batch = self.engine.cache.pop(batch.batch_id)
batch.free_self()
del batch
def _handle_finish_req(self, batch: Batch, has_new_finished_req):
if has_new_finished_req:
finished_reqs = batch.filter_finished()
if batch.is_clear():
self._remove_batch(batch)
else:
self._filter_batch(batch)
yield from self._output_process(finished_reqs)
def _filter_runing_batch(self):
if self.running_batch is not None and self.running_batch.is_clear():
self.running_batch = None
def _add_token_id_to_req(self, batch: Batch, req_ans):
for req_id, (new_token_id, new_gen_metadata) in req_ans.items():
req = batch.id_to_reqs[req_id]
req.output_ids.append(new_token_id)
req.output_metadata_list.append(new_gen_metadata)
return
def _output_process(self, finished_reqs: List[Req]):
"""
Process the output of a batch.
"""
for req in finished_reqs:
output = self.tokenizer.decode(req.output_ids)
yield req.prompts + output
def clean_up(self):
# this logic should be implemented in the future.
pass
def generate(self, request_id, prompts, sampling_params):
"""
Generate the output of a request.
"""
self.add_input(request_id, prompts, sampling_params)
return self.loop_for_fwd()
def is_running(self):
return self.running_batch is not None or self.req_queue.waiting_req_list
def start_dynamic_batching(args, tp_engine, waiting_req_list):
try:
batch_manager = DynamicBatchManager(
tp_engine=tp_engine,
max_total_token_num=args.max_total_token_num,
batch_max_tokens=args.batch_max_tokens,
eos_id=args.eos_id,
model=args.model,
log_stats=not args.disable_log_stats,
log_stats_interval=args.log_stats_interval,
waiting_req_list=waiting_req_list,
)
except Exception:
raise
return batch_manager
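
For offline use the manager can be driven directly: build a TPInferEngine, wrap it with start_dynamic_batching, add inputs, and consume the loop_for_fwd generator. A minimal sketch, assuming a CUDA device; the model path and the router-style args namespace are placeholders mirroring the RooterArgsClass fields:

# Minimal offline sketch; requires a CUDA device. "llama-7b-path" and the args
# values are placeholders.
from argparse import Namespace

import torch
from transformers import AutoModelForCausalLM

from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.shardformer import ShardConfig

model = AutoModelForCausalLM.from_pretrained("llama-7b-path", torch_dtype=torch.float16)
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
engine = TPInferEngine(model, shard_config, 4, 128, 32)  # max_batch_size, max_input_len, max_output_len

args = Namespace(
    max_total_token_num=4 * (128 + 32),
    batch_max_tokens=640,
    eos_id=2,
    model="llama-7b-path",
    disable_log_stats=False,
    log_stats_interval=10,
)
manager = start_dynamic_batching(args, tp_engine=engine, waiting_req_list=[])
manager.add_input("req-0", "Introduce some landmarks in Beijing", SamplingParams(max_new_tokens=32))
for generated_text in manager.loop_for_fwd():
    print(generated_text)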

@ -87,7 +87,6 @@ class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
batch_infer_state.start_loc = seq_start_indexes.to("cuda")
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
batch_infer_state.past_key_values_len = 0
batch_infer_state.is_context_stage = True
batch_infer_state.set_cache_manager(self.cache_manager)
batch_infer_state.cache_manager.free_all()

@ -149,12 +149,6 @@ class LLamaSmoothquantAttention(nn.Module):
self.k_rotary_output_scale.item(),
)
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
@ -229,7 +223,7 @@ class LLamaSmoothquantAttention(nn.Module):
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
@ -592,17 +586,13 @@ def llama_model_forward(
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
infer_state = self.infer_state
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1
if past_key_values is not None:
# NOT READY FOR PRIME TIME
# dummy but work, revise it
past_key_values_length = infer_state.cache_manager.past_key_values_length
# past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
seq_length_with_past = seq_length + past_key_values_length
# NOTE: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
@ -623,9 +613,7 @@ def llama_model_forward(
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
print(f" infer_state.cache_manager.max_len_in_batch: {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
@ -713,6 +701,7 @@ def llama_model_forward(
infer_state.is_context_stage = False
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
next_cache = next_decoder_cache if use_cache else None
if not return_dict:

@ -13,6 +13,8 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy
from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager
# from dynamic_batching.infer_batch import InferBatch
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
_supported_models = [
@ -61,7 +63,6 @@ class TPInferEngine:
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len)
# Constraints relatable with specs of devices and model
# This may change into an optional arg in the future
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
@ -96,6 +97,8 @@ class TPInferEngine:
self.shard_config = shard_config
self.model = None
self.cache = {}
# optimize the original model by sharding with ShardFormer
self._optimize_model(model=model.to(device))
@ -284,7 +287,6 @@ class TPInferEngine:
attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list)
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
@ -318,6 +320,7 @@ class TPInferEngine:
batch_infer_state.past_key_values_len = 0
batch_infer_state.is_context_stage = True
batch_infer_state.set_cache_manager(self.cache_manager)
return batch_infer_state
@torch.no_grad()
@ -381,6 +384,85 @@ class TPInferEngine:
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device)
infer_state.seq_len += 1
@torch.no_grad()
def forward(self, batch_id, is_prefill):
"""
Forward is used in Dynamic Batching Manager
"""
batch = self.cache.pop(batch_id)
if is_prefill:
input_ = torch.tensor(batch.all_input_ids).cuda()
else:
input_ = batch.input_ids.reshape(len(batch), 1)
batch_args = {
"batch_size": len(batch),
"max_len_in_batch": batch.nopad_max_len_in_batch,
"block_loc": batch.nopad_b_loc,
"start_loc": batch.nopad_b_start_loc,
"seq_len": batch.nopad_b_seq_len,
"cache_manager": batch.cache_manager,
"is_context_stage": is_prefill,
}
infer_state = BatchInferState(**batch_args)
model = self.model
if isinstance(model, LlamaForCausalLM):
model = self.model.model
elif isinstance(model, BloomForCausalLM):
model = self.model.transformer
setattr(model, "infer_state", infer_state)
output = self.model.forward(input_ids=input_)
logits = output.logits
# bsz, seq_len, vocab_size
prob_out = torch.softmax(
logits[
:,
-1,
],
dim=-1,
).squeeze(1)
# prob_out: bsz, vocab_size
predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True)
prob_out = torch.log(prob_out).detach().cpu().numpy()
predict_ids = predict_ids.detach().cpu().numpy()
# [ batch_size, 1 ]
output_dict = {}
new_input_ids = []
for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate(
zip(batch.requests, batch.all_input_ids, predict_ids, prob_out)
):
next_token_id = int(next_token_id)
next_token_logprob = next_token_logprob[next_token_id]
# all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda")
all_input_ids.append(next_token_id)
# all_input_ids_tensor = None
new_input_ids.append(next_token_id)
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] += 1
batch.out_token_id_counts[i][next_token_id] += 1
metadata = {
"id": int(next_token_id),
"logprob": float(next_token_logprob),
}
output_dict[r["request_id"]] = (int(next_token_id), metadata)
batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda()
batch.nopad_total_token_num += len(batch)
batch.nopad_max_len_in_batch += 1  # NOTE: we may replace this
self.cache[batch.batch_id] = batch
return output_dict
@torch.no_grad()
def _prefill_batch(self, batch_id):
return self.forward(batch_id, is_prefill=True)
@torch.no_grad()
def _decode_batch(self, batch_id):
return self.forward(batch_id, is_prefill=False)
# might want to create a sequence pool
# add a single request/sequence/input text at a time and record its length
# In other words, store the actual length of input tokens representing a single input text

@ -32,7 +32,7 @@ class MemoryManager:
):
self.logger = logging.get_logger(__name__)
self.available_size = size
self.past_key_values_length = 0
self.max_len_in_batch = 0
self._init_mem_states(size, device)
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
@ -102,5 +102,5 @@ class MemoryManager:
"""free all memory by updating memory states"""
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
self.past_key_values_length = 0
self.max_len_in_batch = 0
self.logger.info("freed all space of memory manager")

@ -133,17 +133,11 @@ class BloomInferenceForwards:
assert hasattr(self, "infer_state")
infer_state = self.infer_state
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
# if self.cache_manager.past_key_values_length > 0:
if infer_state.cache_manager.past_key_values_length > 0:
# update the past key values length in cache manager,
# NOTE use BatchInferState.past_key_values_length instead the one in cache manager
past_key_values_length = infer_state.cache_manager.past_key_values_length
seq_length_with_past = seq_length_with_past + past_key_values_length
# infer_state.cache_manager = self.cache_manager
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1
if use_cache and seq_length != 1:
# prefill stage
@ -160,21 +154,19 @@ class BloomInferenceForwards:
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
@ -195,6 +187,7 @@ class BloomInferenceForwards:
past_key_values_length=past_key_values_length,
)
infer_state.decode_layer_id = 0
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -228,6 +221,7 @@ class BloomInferenceForwards:
infer_state=infer_state,
)
infer_state.decode_layer_id += 1
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
@ -247,7 +241,7 @@ class BloomInferenceForwards:
    # and update this information in engine.generate after the model forward is called
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.decode_layer_id = 0
infer_state.max_len_in_batch += 1
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
@ -453,9 +447,6 @@ class BloomInferenceForwards:
mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id
if layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_length # += 1
if infer_state.is_context_stage:
# context process
max_input_len = q_length
@ -506,15 +497,12 @@ class BloomInferenceForwards:
b_loc,
b_start_loc,
b_seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
alibi,
)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
# update layer id
infer_state.decode_layer_id += 1
# NOTE: always set present as none for now, instead of returning past key value to the next decoding,
# we create the past key value pair from the cache manager
present = None
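To make the replacement of past_key_values_length concrete, a small worked example of the stage-dependent offset used above (numbers are illustrative; max_len_in_batch counts the token being generated in the current step):

# prefill (is_context_stage=True): nothing has been written to the KV cache yet
past_key_values_length = 0

# decode step where the longest running sequence is at its 8th token
max_len_in_batch = 8
past_key_values_length = max_len_in_batch - 1  # 7 positions already cached
# the KV entry produced this step goes into the last slot:
# block_loc[:, max_len_in_batch - 1] = decode_mem_index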

@ -19,8 +19,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
from ._utils import copy_kv_to_mem_cache
try:
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd
from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
context_attention_fwd as lightllm_llama2_context_attention_fwd,
)
HAS_LIGHTLLM_KERNEL = True
except:
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
@ -118,13 +121,12 @@ class ChatGLM2InferenceForwards:
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
past_key_values_length = 0
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1
# NOT READY FOR PRIME TIME
# dummy but work, revise it
past_key_values_length = infer_state.cache_manager.past_key_values_length
seq_length_with_past = seq_length + past_key_values_length
infer_state.seq_length_with_past = seq_length_with_past
# prefill stage at first
if use_cache and seq_length != 1:
@ -272,7 +274,6 @@ class ChatGLM2InferenceForwards:
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
infer_state.cache_manager.past_key_values_length += seq_length
if not return_dict:
return tuple(
@ -487,7 +488,7 @@ class ChatGLM2InferenceForwards:
attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
infer_state.start_loc,
infer_state.seq_len,
infer_state.seq_length_with_past,
infer_state.max_len_in_batch,
)
else:

@ -74,12 +74,11 @@ class LlamaInferenceForwards:
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
        batch_size = input_ids.shape[0]
infer_state = self.infer_state
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
@ -90,15 +89,10 @@ class LlamaInferenceForwards:
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
# NOT READY FOR PRIME TIME
# dummy but work, revise it
past_key_values_length = infer_state.cache_manager.past_key_values_length
# past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if infer_state.is_context_stage:
past_key_values_length = 0
else:
past_key_values_length = infer_state.max_len_in_batch - 1
# NOTE: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
@ -118,23 +112,23 @@ class LlamaInferenceForwards:
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
print(
f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
)
print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.repeat(batch_size, 1)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
@ -146,11 +140,12 @@ class LlamaInferenceForwards:
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item()
infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@ -158,7 +153,7 @@ class LlamaInferenceForwards:
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
(batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
@ -173,7 +168,6 @@ class LlamaInferenceForwards:
next_decoder_cache = () if use_cache else None
infer_state.decode_layer_id = 0
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx] if past_key_values is not None else None
# NOTE: modify here for passing args to decoder layer
@ -197,8 +191,9 @@ class LlamaInferenceForwards:
# update indices
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@ -224,7 +219,6 @@ class LlamaInferenceForwards:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
@ -280,11 +274,8 @@ class LlamaInferenceForwards:
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
if infer_state.decode_layer_id == 0: # once per model.forward
infer_state.cache_manager.past_key_values_length += q_len # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin
# print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
@ -295,7 +286,6 @@ class LlamaInferenceForwards:
if infer_state.is_context_stage:
# first token generation
# copy key and value calculated in current step to memory manager
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
@ -304,7 +294,6 @@ class LlamaInferenceForwards:
infer_state.context_mem_index,
infer_state.cache_manager,
)
attn_output = torch.empty_like(query_states)
if self.num_key_value_groups == 1:
@ -315,7 +304,7 @@ class LlamaInferenceForwards:
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
lightllm_llama2_context_attention_fwd(
@ -325,7 +314,7 @@ class LlamaInferenceForwards:
attn_output,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
if infer_state.decode_is_contiguous:
@ -363,7 +352,7 @@ class LlamaInferenceForwards:
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
)
else:
Llama2TokenAttentionForwards.token_attn(
@ -374,7 +363,7 @@ class LlamaInferenceForwards:
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
infer_state.cache_manager.past_key_values_length,
infer_state.max_len_in_batch,
infer_state.other_kv_index,
)

@ -2,7 +2,6 @@ try:
import triton
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("Triton is not installed. Please install Triton to use Triton kernels.")

@ -10,7 +10,6 @@ except ImportError:
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
# adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
@triton.jit
def _fwd_copy_kv_cache_dest(
@ -53,7 +52,6 @@ if HAS_TRITON:
assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
num_warps = 2
_fwd_copy_kv_cache_dest[(seq_len,)](
k_ptr,
dest_index_ptr,

@ -18,4 +18,6 @@ SentencePiece
ninja
flash_attn==2.0.5
datasets
pydantic
ray
# auto-gptq does not support torch 1.12 yet

@ -11,6 +11,8 @@ ninja
torch>=1.12
safetensors
einops
pydantic
ray
sentencepiece
google
protobuf

@ -27,8 +27,10 @@ if HAS_LLAMA:
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# -----------------------------------
input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
input_ids = torch.Tensor(
[[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]]
).long()
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for causal lm

@ -52,7 +52,6 @@ def run_chatglm2_test(test_config):
"attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
}
outputs = infer_engine.generate(input_tokens, **generate_kwargs)
assert outputs is not None

@ -0,0 +1,14 @@
engine_config:
  model: MODEL_PATH
  tensor_parallel_size: 1
  max_batch_size: 2
  max_input_len: 1024
  max_output_len: 512
# config for app router deployment
# Resources assigned to each model replica. This should correspond to Ray AIR ScalingConfig.
router_config:
  max_total_token_num: 4096
  batch_max_tokens: 4096
  disable_log_stats: False
  log_stats_interval: 10
  model: MODEL_PATH
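A minimal sketch of how this config file is consumed (mirroring the tests below; it assumes the file is saved as config.yaml and MODEL_PATH is replaced with a real checkpoint path):

from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig

config = RayInitConfig.from_yaml_path("config.yaml")
driver = Driver(
    router_config=config.router_config_data,
    engine_config=config.engine_config_data,
)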

@ -0,0 +1,61 @@
import asyncio
import os
import uuid
import pytest
import colossalai
from colossalai.inference.async_engine import Async_Engine
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
PATH = "config.yaml"
def run_async_engine(path: str):
if not os.path.exists(path):
return
config = RayInitConfig.from_yaml_path(path)
engine_config = config.engine_config_data
model = engine_config.model
if model is None or not os.path.exists(model):
return
prompt = "Introduce some landmarks in London.\n The Tower of London is a historic castle on the north bank of the River Thames in central London. It was founded towards the end of 10"
sampling_params = SamplingParams()
asyncio.run(asy_for_loop_test(config, prompt, sampling_params))
async def get_result(engine, prompt, sampling_params):
request_id = str(uuid.uuid4().hex)
results = engine.generate(request_id, prompt, sampling_params)
async for result in results:
# print(result)
assert result is not None
async def asy_for_loop_test(config, prompt, sampling_params):
router_config = config.router_config_data
engine_config = config.engine_config_data
engine = Async_Engine(router_config=router_config, engine_config=engine_config)
for i in range(10):
print("in for loop", i)
await get_result(engine, prompt, sampling_params)
def check_async_engine(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_async_engine(PATH)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_async_engine():
spawn(check_async_engine, 1)
if __name__ == "__main__":
test_async_engine()

@ -0,0 +1,95 @@
import pytest
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import DynamicBatchManager
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
TP_SIZE = 1
BATCH_SIZE = 2
MAX_INPUT_LEN = 48
MAX_OUTPUT_LEN = 256
def run():
sampling_params = SamplingParams()
req1 = Req(0, [1], sampling_params)
req2 = Req(1, [2], sampling_params)
req3 = Req(2, [3], sampling_params)
    # req 1-3 are initialized as token forward requests
req4 = Req(3, [10, 10, 10, 9, 1], sampling_params)
waiting_list = []
waiting_list.append(req1)
waiting_list.append(req2)
waiting_list.append(req3)
# init model and tp engine
llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
model = LlamaForCausalLM(llama_config)
model = model.half()
shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True)
infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
dynamic_batch_manager = DynamicBatchManager(
tp_engine=infer_engine,
max_total_token_num=640,
batch_max_tokens=608,
eos_id=0,
log_stats=False,
log_stats_interval=10,
waiting_req_list=waiting_list,
model="llama",
)
before_add = len(dynamic_batch_manager.req_queue)
# test add req function
dynamic_batch_manager.add_req(req4.request_id, req4.prompt_ids, req4.sample_params)
assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1
# test abort function
dynamic_batch_manager.abort(req4.request_id)
    assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted is True
# test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested
batch = dynamic_batch_manager.req_queue.generate_new_batch()
assert len(batch) == 2
dynamic_batch_manager._init_batch(batch)
assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None
batch.reqs[0].has_generate_finished = True
# filter one finished
batch.filter_finished()
dynamic_batch_manager._filter_batch(batch)
assert len(dynamic_batch_manager.engine.cache) == 1
# test merge batch
new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch)
assert len(new_batch) == 1
dynamic_batch_manager._init_batch(new_batch)
dynamic_batch_manager._merge_batch(batch, new_batch)
assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2
def check_dynamic_batching_manager(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching_manager():
spawn(check_dynamic_batching_manager, 1)
if __name__ == "__main__":
test_dynamic_batching_manager()

@ -0,0 +1,84 @@
from dataclasses import dataclass
import pytest
import torch
from packaging import version
from transformers import LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
import colossalai
from colossalai.inference.dynamic_batching.io_struct import Req
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.inference.manager import start_dynamic_batching
from colossalai.inference.tensor_parallel import TPInferEngine
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
TP_SIZE = 1
MAX_BATCH_SIZE = 2
MAX_INPUT_LEN = 5
MAX_OUTPUT_LEN = 16
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
@dataclass
class args:
max_total_token_num: int
batch_max_tokens: int
model: str
eos_id: int
disable_log_stats: bool
log_stats_interval: int
def run():
arg = args(
max_total_token_num=42,
model="llama",
batch_max_tokens=42,
eos_id=0,
disable_log_stats=False,
log_stats_interval=10,
)
sampling_params = SamplingParams()
req1 = Req(0, [0, 0, 10, 6, 8], sampling_params)
req2 = Req(1, [10, 10, 10, 10, 10], sampling_params)
req3 = Req(2, [0, 0, 10, 10, 10], sampling_params)
req4 = Req(3, [0, 0, 10, 10, 10], sampling_params)
waiting_list = []
waiting_list.append(req1)
waiting_list.append(req2)
waiting_list.append(req3)
waiting_list.append(req4)
llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=30000, hidden_size=1024)
model = LlamaForCausalLM(llama_config)
model = model.half()
shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True)
infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
batch_manager = start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list)
ans_gen = batch_manager.generate(request_id=5, prompts="hello", sampling_params=sampling_params)
for result in ans_gen:
assert result is not None
def check_dynamic_forward(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run()
@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_dynamic_batching():
spawn(check_dynamic_forward, TP_SIZE)
if __name__ == "__main__":
test_dynamic_batching()

@ -0,0 +1,66 @@
import asyncio
import os
import uuid
import pytest
import colossalai
from colossalai.inference.dynamic_batching.ray_dist_init import Driver
from colossalai.inference.dynamic_batching.ray_init_config import RayInitConfig
from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
PATH = "config.yaml"
def run_ray_dist(path: str):
if not os.path.exists(path):
return
config = RayInitConfig.from_yaml_path(path)
router_config = config.router_config_data
engine_config = config.engine_config_data
model = engine_config.model
if model is None or not os.path.exists(model):
return
driver = Driver(router_config=router_config, engine_config=engine_config)
prompt = "Introduce some landmarks in Beijing"
request_id = str(uuid.uuid4().hex)
sampling_params = SamplingParams()
print("sampling_params: ", sampling_params)
async def get_result(request_id, prompt, sampling_params):
return await driver.async_generate(request_id, prompt, sampling_params)
for test_async in [True, False]:
if test_async:
print("test_async: ", test_async)
result = asyncio.run(get_result(request_id, prompt, sampling_params))
assert result is not None
print("result: ", result)
else:
print("test_async: ", test_async)
result = driver.generate(request_id, prompt, sampling_params)
assert result is not None
print("result: ", result)
    is_running = driver.is_running()
assert is_running is not None
print("is_running: ", is_running)
def check_ray_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_ray_dist(PATH)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_ray_dist():
spawn(check_ray_dist, 1)
if __name__ == "__main__":
test_ray_dist()