diff --git a/colossalai/inference/dynamic_batching/__init__.py b/colossalai/inference/dynamic_batching/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/inference/dynamic_batching/infer_batch.py b/colossalai/inference/dynamic_batching/infer_batch.py new file mode 100644 index 000000000..826272db3 --- /dev/null +++ b/colossalai/inference/dynamic_batching/infer_batch.py @@ -0,0 +1,346 @@ +import collections +from dataclasses import dataclass +from typing import Dict, List , Tuple + +import numpy as np +import torch + +from colossalai.inference.tensor_parallel import MemoryManager + +# make batch infer state an attr of InferBatch + + +class InferSamplingParams: + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + vocab_size: int = -1, + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + if self.top_k == -1: + self.top_k = vocab_size + return + + +@dataclass +class InferBatch: + batch_id: int + requests: List + requests_idx_mapping: Dict[int, int] + + input_ids: torch.Tensor + + all_input_ids: List[List[int]] + input_lengths: List[int] + + out_token_id_counts: List + sampling_param_list: List[InferSamplingParams] + + nopad_total_token_num: int + nopad_max_len_in_batch: int + nopad_b_loc: torch.Tensor + nopad_b_start_loc: torch.Tensor + nopad_b_seq_len: torch.Tensor + cache_manager: MemoryManager + max_total_len: int + + @classmethod + @torch.no_grad() + def init_batch( + cls, + batch_id, + requests, + dtype: torch.dtype, + device: torch.device, + cache_manager: MemoryManager, + vocab_size: int, + max_total_len: int, + ) -> 'InferBatch': + input_lengths = [] + all_input_ids = [] + requests_idx_mapping = {} + + out_token_id_counts = [] + sampling_param_list = [] + + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda") + # to avoid memory leak , we pre-allocate 12 more space for each batch. 
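+ # nopad_b_loc keeps, per request, the KV-cache slot indices of its tokens, right-aligned towards column (nopad_max_len_in_batch - 1)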
+ nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda") + for i, r in enumerate(requests): + # request id -> idx in list mapping + requests_idx_mapping[r["request_id"]] = i + + tokenized_input = r["input_id"] + + input_length = len(tokenized_input) + input_lengths.append(input_length) + all_input_ids.append(tokenized_input) + out_token_id_counts.append(collections.defaultdict(int)) + + # postprocessor + sampling_param = r["sampling_param"] + sampling_param["vocab_size"] = vocab_size + sampling_param_list.append(InferSamplingParams(**sampling_param)) + + nopad_total_token_num += input_length + nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length) + + nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda") + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + + if len(requests) > 1: + input_ids = np.concatenate(all_input_ids, dtype=np.int64) + else: + input_ids = all_input_ids[0] + + # Create tensors on device + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device) + + return cls( + batch_id=batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + cache_manager=cache_manager, + max_total_len=max_total_len, + ) + + @torch.no_grad() + def free_self(self) -> None: + """ + Free the memory of the InferBatch itself + """ + remove_index = [] + for idx in range(len(self)): + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) + remove_index = torch.cat(remove_index, dim=-1) + self.cache_manager.free(remove_index) + + + @torch.no_grad() + def filter(self, request_ids: List[int]) -> 'InferBatch': + """ + Filter finished batch and return a new InferBatch with left ones. 
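+ KV-cache slots held by the dropped requests are freed back to the cache manager before the kept rows are re-packed.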
+ """ + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + requests_idx_mapping = {} + indices = [] + requests = [] + all_input_ids = [] + input_lengths = [] + nopad_total_token_num = 0 + nopad_max_len_in_batch = 0 + nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda") + + left_idx = [] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + left_idx.append(idx) + + left_idx_set = set(left_idx) + remove_index = [] + for idx in range(len(self)): + if idx not in left_idx_set: + remove_index.append( + self.nopad_b_loc[ + idx, + (self.nopad_max_len_in_batch - 1) + - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1), + ] + ) + remove_index = torch.cat(remove_index, dim=-1) + self.cache_manager.free(remove_index) + + nopad_max_len_in_batch = 0 + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + indices.append(idx) + + nopad_b_seq_len[:] = self.nopad_b_seq_len[indices] + nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item() + nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1] + nopad_total_token_num = torch.sum(nopad_b_seq_len).item() + + nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[ + indices, + (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1), + ] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + requests.append(self.requests[idx]) + all_input_ids.append(self.all_input_ids[idx]) + input_lengths.append(self.input_lengths[idx]) + + input_ids = self.input_ids[indices] + + return InferBatch( + batch_id=self.batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices], + sampling_param_list=[self.sampling_param_list[_i] for _i in indices], + cache_manager=self.cache_manager, + max_total_len=self.max_total_len, + ) + + @classmethod + @torch.no_grad() + def merge(cls, batch1, batch2) -> 'InferBatch': + """ + Return megerd new InferBatch + """ + requests = batch1.requests + batch2.requests + requests_idx_mapping = {} + new_batch_size = len(batch1) + len(batch2) + + input_ids = batch1.input_ids.new_empty(new_batch_size) + all_input_ids = [] + input_lengths = [] + out_token_id_counts = [] + sampling_param_list = [] + + cumulative_batch_size = 0 + nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num + nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch) + max_total_len = max(batch1.max_total_len, batch2.max_total_len) + nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda") + nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda") + nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, 
device="cuda") + nopad_start_loc_len_temp = 0 + batches = [batch1, batch2] + for i, batch in enumerate(batches): + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + cumulative_batch_size + start_index = cumulative_batch_size + end_index = cumulative_batch_size + len(batch) + input_ids[start_index:end_index] = batch.input_ids + nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len + nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp + nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1] + nopad_b_loc[ + start_index:end_index, + nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1, + ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1] + + all_input_ids.extend(batch.all_input_ids) + + input_lengths.extend(batch.input_lengths) + out_token_id_counts.extend(batch.out_token_id_counts) + sampling_param_list.extend(batch.sampling_param_list) + # Update + cumulative_batch_size += len(batch) + + nopad_b_loc[:, nopad_max_len_in_batch - 1] = ( + nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda") + ) + return InferBatch( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + input_lengths=input_lengths, + all_input_ids=all_input_ids, + nopad_total_token_num=nopad_total_token_num, + nopad_max_len_in_batch=nopad_max_len_in_batch, + nopad_b_loc=nopad_b_loc, + nopad_b_start_loc=nopad_b_start_loc, + nopad_b_seq_len=nopad_b_seq_len, + out_token_id_counts=out_token_id_counts, + sampling_param_list=sampling_param_list, + cache_manager=batches[0].cache_manager, + max_total_len=max_total_len, + ) + + def __len__(self): + return len(self.requests) + + def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + top_ks: List[int] = [] + p_token_ids: List[int] = [] + p_token_counts: List[int] = [] + p_seq_len: List[int] = [ + 0, + ] + p_max_len_in_batch: int = 0 + for i, id_to_count in enumerate(self.out_token_id_counts): + sample_param = self.sampling_param_list[i] + presence_penalties.append(sample_param.presence_penalty) + frequency_penalties.append(sample_param.frequency_penalty) + temperatures.append(sample_param.temperature) + top_ps.append(sample_param.top_p) + top_ks.append(sample_param.top_k) + + for token_id, count in id_to_count.items(): + p_token_ids.append(token_id) + p_token_counts.append(count) + p_seq_len.append(len(id_to_count)) + p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count)) + + presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda") + frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda") + temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda") + top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda") + top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda") + p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda") + p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda") + p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda") + p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, 
dtype=torch.int32) + return ( + presence_penalties, + frequency_penalties, + temperatures, + top_ps, + top_ks, + p_token_ids, + p_token_counts, + p_cumsum_seq_len, + p_max_len_in_batch, + ) diff --git a/colossalai/inference/dynamic_batching/io_struct.py b/colossalai/inference/dynamic_batching/io_struct.py new file mode 100644 index 000000000..2b2739f0a --- /dev/null +++ b/colossalai/inference/dynamic_batching/io_struct.py @@ -0,0 +1,149 @@ +from typing import Dict, List, Tuple + +from .sampling_params import SamplingParams + + +class Req: + def __init__(self, request_id, prompt_ids, sample_params: SamplingParams): + self.request_id = request_id + self.prompt_ids = prompt_ids + self.input_len = len(prompt_ids) + self.max_output_len = sample_params.max_new_tokens + self.sample_params = sample_params + self.output_ids = [] + self.output_metadata_list = [] + self.has_generate_finished = False + self.aborted = False + + def to_rpc_obj(self): + return { + "request_id": self.request_id, + "input_id": self.prompt_ids, + "output_len": self.max_output_len, + "sampling_param": self.sample_params.to_dict(), + } + + def to_req_detokenization_state(self): + out = ReqDetokenizationState( + self.request_id, self.prompt_ids, self.max_output_len, self.sample_params.ignore_eos + ) + if self.output_metadata_list: + out.gen_metadata.update(self.output_metadata_list[-1]) + return out + + def stop_sequences_matched(self): + # should we add stpp sequences to the sample params? + if self.sample_params.stop_sequences is not None: + for stop_token_ids in self.sample_params.stop_sequences: + stop_len = len(stop_token_ids) + if stop_len > 0 and len(self.output_ids) >= stop_len and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)): + return True + return False + + def __repr__(self): + return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, " + + +class ReqDetokenizationState: + def __init__( + self, + request_id: str, + prompt_ids: List[int], + max_output_len: int, + ignore_eos: bool, + ) -> None: + self.request_id = request_id + self.prompt_ids = prompt_ids + self.output_ids = [] + self.output_tokens = [] + self.output_str = "" + self.sub_texts = [] + self.current_sub_text = [] + self.max_output_len = max_output_len + self.ignore_eos = ignore_eos + self.gen_metadata = {} + + +class Batch: + def __init__(self, batch_id, reqs: List[Req]): + self.batch_id = batch_id + self.reqs = reqs + self.id_to_reqs = {req.request_id: req for req in reqs} + + def input_tokens(self): + batch_input_tokens = 0 + for req in self.reqs: + batch_input_tokens += req.input_len + return batch_input_tokens + + def calcu_max_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + req.max_output_len + return tokens + + def calcu_used_tokens(self): + tokens = 0 + for req in self.reqs: + tokens += req.input_len + len(req.output_ids) + return tokens + + def mark_finished_req(self, eos_id): + has_new_finish = False + for req in self.reqs: + if req.stop_sequences_matched(): + req.has_generate_finished = True + has_new_finish = True + if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False: + req.has_generate_finished = True + has_new_finish = True + if len(req.output_ids) >= req.max_output_len or req.aborted: + req.has_generate_finished = True + has_new_finish = True + return has_new_finish + + def filter_finished(self): + """ + Filter finished requests from the batch, the finished ones will be removed from 'reqs'. 
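+ The 'id_to_reqs' mapping is rebuilt afterwards so request_id lookups stay consistent with the filtered 'reqs' list.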
+ """ + # TODO: the logic of return should be defined here. + unfinished_req = [] + for req in self.reqs: + if not req.has_generate_finished: + unfinished_req.append(req) + self.reqs = unfinished_req + self.id_to_reqs = {req.request_id: req for req in self.reqs} + + def is_clear(self): + return len(self.reqs) == 0 + + def merge(self, mini_batch): + for _req in mini_batch.reqs: + self.reqs.append(_req) + self.id_to_reqs = {req.request_id: req for req in self.reqs} + return + + def __repr__(self): + return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, " + + def __len__(self): + return len(self.reqs) + + +class BatchTokenIdOut: + def __init__(self): + self.reqs_infs: List[ + Tuple[str, int, Dict, bool, bool] + ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state] + + +class BatchStrOut: + def __init__(self): + self.reqs_infs: List[ + Tuple[str, str, Dict, bool, bool] + ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state] + + +class AbortReq: + def __init__(self, req_id): + self.req_id = req_id diff --git a/colossalai/inference/dynamic_batching/req_queue.py b/colossalai/inference/dynamic_batching/req_queue.py new file mode 100644 index 000000000..d9e9b6269 --- /dev/null +++ b/colossalai/inference/dynamic_batching/req_queue.py @@ -0,0 +1,71 @@ +import uuid +from typing import List + +import numpy as np + +from .io_struct import Batch, Req + + +class ReqQueue: + def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None: + self.max_total_tokens = max_total_tokens + assert batch_max_tokens is not None + self.batch_max_tokens = batch_max_tokens + self.running_max_req_size = running_max_req_size + self.waiting_req_list: List[Req] = waiting_req_list + + def append(self, req): + self.waiting_req_list.append(req) + return + + def _init_cache_list(self, current_batch: Batch): + if current_batch is not None: + self.cache_len_list = [ + (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1) + for req in current_batch.reqs + ] + else: + self.cache_len_list = [] + + # @calculate_time(show=True, min_cost_ms=0.1) + def _can_add_new_req(self, req): + self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis + self.cache_len_list.sort(key=lambda x: -x[1]) + + left_out_len_array = np.array([e[1] for e in self.cache_len_list]) + # assert left_out_len_array.min() >= 0 + has_run_len_array = np.array([e[0] for e in self.cache_len_list]) + cum_run_len_array = np.cumsum(has_run_len_array) + size_array = np.arange(1, len(self.cache_len_list) + 1, 1) + + need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + # NOTE: change here < to <= + return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size + + def generate_new_batch(self, current_batch: Batch = None): + if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size: + return None + self._init_cache_list(current_batch) + can_run_list = [] + new_batch_total_tokens = 0 + aborted_count = 0 + for req in self.waiting_req_list: + flag = self._can_add_new_req(req) + if req.aborted: + aborted_count += 1 + continue + if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens: + can_run_list.append(req) + new_batch_total_tokens += req.input_len + else: + break + + if len(can_run_list) != 0: + new_batch = Batch(uuid.uuid4().hex, can_run_list) + self.waiting_req_list = self.waiting_req_list[len(can_run_list) + 
aborted_count :] + return new_batch + else: + return None + + def __len__(self): + return self.waiting_req_list.__len__() diff --git a/colossalai/inference/dynamic_batching/sampling_params.py b/colossalai/inference/dynamic_batching/sampling_params.py new file mode 100644 index 000000000..9a0ace411 --- /dev/null +++ b/colossalai/inference/dynamic_batching/sampling_params.py @@ -0,0 +1,82 @@ +"""Sampling parameters for text generation.""" +from typing import List, Optional, Union + +_SAMPLING_EPS = 1e-5 + + +class SamplingParams: + + def __init__( + self, + do_sample: bool = False, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, # -1 is for all + ignore_eos: bool = False, + max_new_tokens: int = 16, + stop_sequences: Optional[Union[str, List[str]]] = None # conditions to stop generation + ) -> None: + self.do_sample = do_sample + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.ignore_eos = ignore_eos + self.max_new_tokens = max_new_tokens + self.stop_sequences = stop_sequences + if self.do_sample == False: + self.temperature = 1.0 + self.top_p = 1.0 + self.top_k = 1 + if self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS: # temperature is too slow, change to greedy search + self.temperature = 1.0 + self.top_k = 1 + return + + def verify(self): + if self.presence_penalty < 0.0: + raise ValueError(f"presence_penalty must >= 0.0, got {self.presence_penalty}") + if self.frequency_penalty < 0.0: + raise ValueError(f"frequency_penalty must >= 0.0, got {self.frequency_penalty}") + if self.temperature <= 0.0: + raise ValueError(f"temperature must > 0.0, got {self.temperature}") + if self.top_p <= 0.0 or self.top_p > 1.0: + raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}") + if self.top_k < -1 or self.top_k == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.") + if self.max_new_tokens < 1: + raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.") + return + + def stop_sentences_to_token_ids(self, tokenizer): + if self.stop_sequences is None: + self.stop_sequences = [] + else: + if isinstance(self.stop_sequences, str): + self.stop_sequences = [self.stop_sequences] + new_stop_sequences = [] + for stop_str in self.stop_sequences: + stop_str_ids = tokenizer.encode(stop_str) + if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id + stop_str_ids = stop_str_ids[1:] + if len(stop_str_ids) > 0: + new_stop_sequences.append(stop_str_ids) + self.stop_sequences = new_stop_sequences + return + + def to_dict(self): + ret = {} + ret["do_sample"] = self.do_sample + ret["presence_penalty"] = self.presence_penalty + ret["frequency_penalty"] = self.frequency_penalty + ret["temperature"] = self.temperature + ret["top_p"] = self.top_p + ret["top_k"] = self.top_k + # if self.ignore_eos is not None: + # ret["ignore_eos"] = self.ignore_eos + # if self.max_tokens is not None: + # ret["max_tokens"] = self.max_tokens + return ret diff --git a/colossalai/inference/dynamic_batching/stats.py b/colossalai/inference/dynamic_batching/stats.py new file mode 100644 index 000000000..6d34183f4 --- /dev/null +++ b/colossalai/inference/dynamic_batching/stats.py @@ -0,0 +1,43 @@ +import time + + +class Stats: + def __init__(self, log_status, log_stats_interval) -> None: + self.log_stats = log_status + self.log_stats_interval = 
log_stats_interval + self.last_log_time = time.time() + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + return + + def count_prompt_tokens(self, run_batch): + if self.log_stats: + tokens = run_batch.input_tokens() + self.prompt_tokens += tokens + self.all_tokens += tokens + return + + def count_output_tokens(self, run_batch): + if self.log_stats: + tokens = len(run_batch.reqs) + self.output_tokens += tokens + self.all_tokens += tokens + return + + def print_stats(self): + if not self.log_stats: + return + + now = time.time() + if now - self.last_log_time > self.log_stats_interval: + print( + f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n" + f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s" + ) + self.all_tokens = 0 + self.output_tokens = 0 + self.prompt_tokens = 0 + self.last_log_time = now + return diff --git a/colossalai/inference/manager.py b/colossalai/inference/manager.py new file mode 100644 index 000000000..72f774067 --- /dev/null +++ b/colossalai/inference/manager.py @@ -0,0 +1,243 @@ +import time +from typing import List + +from .dynamic_batching.infer_batch import InferBatch +from .dynamic_batching.io_struct import Batch, Req +from .dynamic_batching.req_queue import ReqQueue +from .dynamic_batching.sampling_params import SamplingParams +from .dynamic_batching.stats import Stats +from .tensor_parallel import TPInferEngine + + +class DynamicBatchManager: + def __init__( + self, + tp_engine: TPInferEngine, + max_total_token_num, + batch_max_tokens, + eos_id, + log_stats=True, + log_stats_interval=10, + running_batch: Batch = None, + waiting_req_list: List = [], + ): + """ + Args: tp_engine : The tp engine that dynamic batch manager hold, defined before dynamic batch manager + max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len) + batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests + running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine + eos_id : The end token of a seq + log_stats : whether to log stats + log_stats_interval : log stats interval + running_batch : running batch + waiting_req_list : list of waiting requests, initialized before dynamic batch manager + """ + self.engine = tp_engine + self.max_total_token_num = max_total_token_num + running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2 + self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list) + # all the inputs should be put into req_queue: waiting req list + + self.running_batch: Batch = running_batch + self.eos_id = eos_id + self.has_wait_tokens = 0 + self.max_wait_tokens = 10 + + self.stats_tool = Stats(log_stats, log_stats_interval) + self.mem_usage_interval = log_stats_interval * 2 + + def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str): + """ + Add new request to req queue, during initialization all requests are held in waiting list. 
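+ The request is moved into a running batch only when ReqQueue.generate_new_batch later selects it in _step.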
+ """ + req = Req(request_id, prompt_ids, sampling_params) + self.req_queue.append(req) + return + + def abort(self, request_id): + if self.running_batch is not None: + for req in self.running_batch.reqs: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + for req in self.req_queue.waiting_req_list: + if req.request_id == request_id: + req.has_generate_finished = True + req.aborted = True + return + + def loop_for_fwd(self): + """ + The main loop for a dynamic batching process. + """ + counter_count = 0 + while self.running_batch is not None or self.req_queue.waiting_req_list: + self._step() + counter_count += 1 + if self.running_batch is not None: + if counter_count % self.mem_usage_interval == 0: + print( + "current batch size:", + len(self.running_batch.reqs), + "token used ratio:", + self.running_batch.calcu_used_tokens() / self.max_total_token_num, + ) + self.stats_tool.print_stats() + + if self.running_batch is None: + time.sleep(0.1) # 10ms + + def _step(self): + """ + Logic for handling requests + """ + + if self.running_batch is None: + new_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_batch is not None: + self.stats_tool.count_prompt_tokens(new_batch) + self.running_batch = new_batch + self._prefill_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens = 0 + return + + if self.has_wait_tokens < self.max_wait_tokens: + self.stats_tool.count_output_tokens(self.running_batch) + self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + return + else: + new_mini_batch = self.req_queue.generate_new_batch(self.running_batch) + if new_mini_batch is not None: + self.stats_tool.count_prompt_tokens(new_mini_batch) + self._prefill_batch(new_mini_batch) + if not new_mini_batch.is_clear(): + self._merge_batch(self.running_batch, new_mini_batch) + self.running_batch.merge(new_mini_batch) + self.has_wait_tokens = 0 + else: + self.stats_tool.count_output_tokens(self.running_batch) + self._decode_batch(self.running_batch) + self._filter_runing_batch() + self.has_wait_tokens += 1 + + return + + def _init_batch(self, batch: Batch, dtype="fp16"): + reqs = [r.to_rpc_obj() for r in batch.reqs] + batch_id = batch.batch_id + + import torch + + if dtype == "fp16": + dtype = torch.float16 + else: + assert False, "error dtype" + + batch_data = InferBatch.init_batch( + batch_id, + reqs, + dtype, + torch.cuda.current_device(), + self.engine.cache_manager, + self.engine.model.config.vocab_size, + self.engine.max_input_len + self.engine.max_output_len, + ) + self.engine.cache[batch_id] = batch_data + + def _prefill_batch(self, batch): + """ + For all batches, no matter it is a new batch or a mini batch, we need to do prefill first. 
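+ Prefill registers the batch with the engine cache (_init_batch), runs one forward pass over the prompts, appends the first generated token to each request, and filters out any request that has already finished.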
+ """ + self._init_batch(batch) + + # TODO: figure out if cache and batch id is needed + ans = self.engine._prefill_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + self._handle_finish_req(batch, has_new_finished_req) + # delete finished reqs + + def _decode_batch(self, batch: Batch): + """ + Decoding process + """ + ans = self.engine._decode_batch(batch.batch_id) + req_to_out_token_id = ans + self._add_token_id_to_req(batch, req_to_out_token_id) + has_new_finished_req = batch.mark_finished_req(self.eos_id) + self._handle_finish_req(batch, has_new_finished_req) + + def _filter_batch(self, batch: Batch): + batch_id = batch.batch_id + req_id_list = [r.request_id for r in batch.reqs] + batch = self.engine.cache.pop(batch_id) + filter_batch = batch.filter(req_id_list) + del batch + self.engine.cache[batch_id] = filter_batch + + def _merge_batch(self, batch1, batch2): + """ + Merge new mini batch into running batch. + """ + batch1 = self.engine.cache.pop(batch1.batch_id) + batch2 = self.engine.cache.pop(batch2.batch_id) + + m_batch = InferBatch.merge(batch1, batch2) + self.engine.cache[batch1.batch_id] = m_batch + del batch1 + del batch2 + + def _remove_batch(self, batch): + """ + Remove finished batch. + """ + batch = self.engine.cache.pop(batch.batch_id) + batch.free_self() + del batch + + def _handle_finish_req(self, batch: Batch, has_new_finished_req): + if has_new_finished_req: + batch.filter_finished() + if batch.is_clear(): + self._remove_batch(batch) + else: + self._filter_batch(batch) + + def _filter_runing_batch(self): + if self.running_batch is not None and self.running_batch.is_clear(): + self.running_batch = None + + def _add_token_id_to_req(self, batch: Batch, req_ans): + for req_id, (new_token_id, new_gen_metadata) in req_ans.items(): + req = batch.id_to_reqs[req_id] + req.output_ids.append(new_token_id) + req.output_metadata_list.append(new_gen_metadata) + return + + def clean_up(self): + # this logic should be implemented in the future. 
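+ # currently a no-op; the commented-out exception handler in start_dynamic_batching below would call this on failure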
+ pass + + +def start_dynamic_batching(args, tp_engine, waiting_req_list): + # try: + batch_manager = DynamicBatchManager( + tp_engine=tp_engine, + max_total_token_num=args.max_total_token_num, + batch_max_tokens=args.batch_max_tokens, + eos_id=args.eos_id, + log_stats=not args.disable_log_stats, + log_stats_interval=args.log_stats_interval, + waiting_req_list=waiting_req_list, + ) + + # except Exception: + # batch_manager.clean_up() + # raise + + batch_manager.loop_for_fwd() + return diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index d5ef37fee..f7fb7a825 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -1,7 +1,6 @@ from typing import Any, Callable, List, Optional, Union import torch -import torch.distributed as dist import torch.nn as nn from transformers import BloomForCausalLM, LlamaForCausalLM from transformers.generation import GenerationConfig @@ -14,6 +13,8 @@ from colossalai.shardformer.policies.auto_policy import get_autopolicy from .batch_infer_state import BatchInferState from .kvcache_manager import MemoryManager +# from dynamic_batching.infer_batch import InferBatch + DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2 _supported_models = [ @@ -90,6 +91,8 @@ class TPInferEngine: self.shard_config = shard_config self.model = None + self.cache = {} + # optimize the original model by sharding with ShardFormer self._optimize_model(model=model.to(device)) @@ -116,13 +119,15 @@ class TPInferEngine: def _post_init_gptq_buffer(self, model: nn.Module) -> None: from colossalai.inference.quant.gptq.cai_gptq import CaiQuantLinear + HAS_GPTQ_CUDA = False try: from colossalai.kernel.op_builder.gptq import GPTQBuilder + gptq_cuda = GPTQBuilder().load() HAS_GPTQ_CUDA = True except ImportError: - warnings.warn('CUDA gptq is not installed') + warnings.warn("CUDA gptq is not installed") HAS_GPTQ_CUDA = False for name, submodule in model.named_modules(): @@ -130,8 +135,9 @@ class TPInferEngine: self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8) if self.use_act_order: - self.max_inner_outer_dim = max(self.max_inner_outer_dim, submodule.infeatures, - submodule.outfeatures) + self.max_inner_outer_dim = max( + self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures + ) self.bits = submodule.bits if not (HAS_GPTQ_CUDA and self.bits == 4): return @@ -141,15 +147,16 @@ class TPInferEngine: max_input_len = self.max_input_len # The temp_state buffer is required to reorder X in the act-order case. # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. 
- self.gptq_temp_state_buffer = torch.zeros((max_input_len, self.max_inner_outer_dim), - dtype=torch.float16, - device=torch.cuda.current_device()) - self.gptq_temp_dq_buffer = torch.zeros((1, self.max_dq_buffer_size), - dtype=torch.float16, - device=torch.cuda.current_device()) - - gptq_cuda.prepare_buffers(torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, - self.gptq_temp_dq_buffer) + self.gptq_temp_state_buffer = torch.zeros( + (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device() + ) + self.gptq_temp_dq_buffer = torch.zeros( + (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device() + ) + + gptq_cuda.prepare_buffers( + torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer + ) # Using the default from exllama repo here. matmul_recons_thd = 8 matmul_fused_remap = False @@ -270,7 +277,6 @@ class TPInferEngine: attention_mask = [attention_mask] if attention_mask is not None else attention_mask batch_size = len(input_ids_list) - seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda") seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda") start_index = 0 @@ -304,6 +310,7 @@ class TPInferEngine: batch_infer_state.past_key_values_len = 0 batch_infer_state.is_context_stage = True batch_infer_state.set_cache_manager(self.cache_manager) + return batch_infer_state @torch.no_grad() @@ -367,6 +374,86 @@ class TPInferEngine: infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device) infer_state.seq_len += 1 + @torch.no_grad() + def forward(self, batch_id, is_prefill): + """ + Forward is used in Dynamic Batching Manager + """ + batch = self.cache.pop(batch_id) + + if is_prefill: + input_ = torch.tensor(batch.all_input_ids).cuda() + else: + input_ = batch.input_ids.reshape(len(batch), 1) + + batch_args = { + "batch_size": len(batch), + "max_len_in_batch": batch.nopad_max_len_in_batch, + "block_loc": batch.nopad_b_loc, + "start_loc": batch.nopad_b_start_loc, + "seq_len": batch.nopad_b_seq_len, + "cache_manager": batch.cache_manager, + "is_context_stage": is_prefill, + } + + infer_state = BatchInferState(**batch_args) + model = self.model + if isinstance(model, LlamaForCausalLM): + model = self.model.model + elif isinstance(model, BloomForCausalLM): + model = self.model.transformer + setattr(model, "infer_state", infer_state) + + output = self.model.forward(input_ids=input_) + logits = output.logits + # bsz, seq_len, vocab_size + prob_out = torch.softmax( + logits[ + :, + -1, + ], + dim=-1, + ).squeeze(1) + # prob_out: bsz, vocab_size + predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True) + prob_out = torch.log(prob_out).detach().cpu().numpy() + predict_ids = predict_ids.detach().cpu().numpy() + # [ batch_size, 1 ] + + output_dict = {} + new_input_ids = [] + for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate( + zip(batch.requests, batch.all_input_ids, predict_ids, prob_out) + ): + next_token_id = int(next_token_id) + next_token_logprob = next_token_logprob[next_token_id] + # all_input_ids_tensor = torch.tensor(all_input_ids, dtype=torch.long, device="cuda") + all_input_ids.append(next_token_id) + # all_input_ids_tensor = None + new_input_ids.append(next_token_id) + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] += 1 + batch.out_token_id_counts[i][next_token_id] += 1 + metadata = { + "id": int(next_token_id), + "logprob": 
float(next_token_logprob), + } + output_dict[r["request_id"]] = (int(next_token_id), metadata) + + batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda() + batch.nopad_total_token_num += len(batch) + batch.nopad_max_len_in_batch += 1 + self.cache[batch.batch_id] = batch + return output_dict + + @torch.no_grad() + def _prefill_batch(self, batch_id): + return self.forward(batch_id, is_prefill=True) + + @torch.no_grad() + def _decode_batch(self, batch_id): + return self.forward(batch_id, is_prefill=False) + # might want to create a sequence pool # add a single request/sequence/input text at a time and record its length # In other words, store the actual length of input tokens representing a single input text diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/tensor_parallel/modeling/_utils.py index e476c3132..068b64b4f 100644 --- a/colossalai/inference/tensor_parallel/modeling/_utils.py +++ b/colossalai/inference/tensor_parallel/modeling/_utils.py @@ -45,7 +45,7 @@ def init_to_get_rotary(self, base=10000, use_elem=False): base = float(base) # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ - ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", None)) + ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None) if ntk_alpha is not None: ntk_alpha = float(ntk_alpha) diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index a7661cee1..958868a09 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -62,12 +62,11 @@ class LlamaInferenceForwards: output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): - batch_size = input_ids.shape[0] # input_ids.shape[0] - infer_state = self.infer_state return_dict = return_dict if return_dict is not None else self.config.use_return_dict + use_cache = use_cache if use_cache is not None else self.config.use_cache # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") @@ -78,15 +77,12 @@ class LlamaInferenceForwards: else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - # NOT READY FOR PRIME TIME - # dummy but work, revise it - past_key_values_length = infer_state.cache_manager.past_key_values_length - # past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length + # NOT READY FOR PRIME TIME + # dummy but work, revise it + if infer_state.is_context_stage: + past_key_values_length = 0 + else: + past_key_values_length = infer_state.max_len_in_batch - 1 # NOTE: differentiate with prefill stage # block_loc require different value-assigning method for two different stage @@ -106,23 +102,23 @@ class LlamaInferenceForwards: infer_state.decode_mem_index = alloc_mem[0] infer_state.decode_mem_start = alloc_mem[1] infer_state.decode_mem_end = alloc_mem[2] - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index else: print(f" *** Encountered allocation non-contiguous") - print( - f" 
infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}" - ) + print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}") infer_state.decode_is_contiguous = False alloc_mem = infer_state.cache_manager.alloc(batch_size) infer_state.decode_mem_index = alloc_mem # infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") # infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda") - infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index + infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index + if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) + position_ids = position_ids.repeat(batch_size, 1) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -134,6 +130,7 @@ class LlamaInferenceForwards: infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view( position_ids.view(-1).shape[0], -1 ) + else: seq_len = infer_state.seq_len infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1) @@ -145,7 +142,7 @@ class LlamaInferenceForwards: # embed positions if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device ) attention_mask = self._prepare_decoder_attention_mask( @@ -160,7 +157,6 @@ class LlamaInferenceForwards: next_decoder_cache = () if use_cache else None infer_state.decode_layer_id = 0 - for idx, decoder_layer in enumerate(self.layers): past_key_value = past_key_values[idx] if past_key_values is not None else None # NOTE: modify here for passing args to decoder layer @@ -184,7 +180,7 @@ class LlamaInferenceForwards: # update indices # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") - infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda") + infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda") infer_state.seq_len += 1 if not return_dict: @@ -211,7 +207,6 @@ class LlamaInferenceForwards: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, @@ -267,11 +262,8 @@ class LlamaInferenceForwards: # NOTE might want to revise # need some way to record the length of past key values cache # since we won't return past_key_value_cache right now - if infer_state.decode_layer_id == 0: # once per model.forward - infer_state.cache_manager.past_key_values_length += q_len # seq_len cos, sin = infer_state.position_cos, infer_state.position_sin - # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, ) rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin) rotary_embedding_fwd(key_states.view(-1, self.num_heads, self.head_dim), cos, sin) @@ -282,7 +274,6 @@ class 
LlamaInferenceForwards: if infer_state.is_context_stage: # first token generation - # copy key and value calculated in current step to memory manager copy_kv_to_mem_cache( infer_state.decode_layer_id, @@ -291,9 +282,7 @@ class LlamaInferenceForwards: infer_state.context_mem_index, infer_state.cache_manager, ) - attn_output = torch.empty_like(query_states) - llama_context_attn_fwd( query_states, key_states, @@ -301,7 +290,7 @@ class LlamaInferenceForwards: attn_output, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) else: if infer_state.decode_is_contiguous: @@ -338,7 +327,7 @@ class LlamaInferenceForwards: infer_state.block_loc, infer_state.start_loc, infer_state.seq_len, - infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py index 983069158..070ebe45f 100644 --- a/colossalai/kernel/triton/__init__.py +++ b/colossalai/kernel/triton/__init__.py @@ -2,7 +2,6 @@ try: import triton HAS_TRITON = True - except ImportError: HAS_TRITON = False print("Triton is not installed. Please install Triton to use Triton kernels.") diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py index 02edcc9a9..0520bc111 100644 --- a/colossalai/kernel/triton/copy_kv_cache_dest.py +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -51,7 +51,6 @@ if HAS_TRITON: assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" num_warps = 2 - _fwd_copy_kv_cache_dest[(seq_len,)]( k_ptr, dest_index_ptr, diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py index bc229b17e..1d1e154b6 100644 --- a/tests/kit/model_zoo/transformers/llama.py +++ b/tests/kit/model_zoo/transformers/llama.py @@ -27,8 +27,10 @@ if HAS_LLAMA: # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') # ----------------------------------- - input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long() - attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long() + input_ids = torch.Tensor( + [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]] + ).long() + attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long() return dict(input_ids=input_ids, attention_mask=attention_mask) # label is needed for casual lm diff --git a/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py new file mode 100644 index 000000000..124f1f478 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_dynamic_batching_manager.py @@ -0,0 +1,94 @@ +import pytest +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from colossalai.inference.dynamic_batching.io_struct import Req +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import DynamicBatchManager +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +TP_SIZE = 1 +BATCH_SIZE = 2 +MAX_INPUT_LEN = 5 +MAX_OUTPUT_LEN = 16 + + +def run(): + 
sampling_params = SamplingParams() + + req1 = Req(0, [1], sampling_params) + req2 = Req(1, [2], sampling_params) + req3 = Req(2, [3], sampling_params) + # req 1-3 are initiliazed as token forward requests + req4 = Req(3, [10, 10, 10, 9, 1], sampling_params) + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) + waiting_list.append(req3) + + # init model and tp engine + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig(enable_tensor_parallelism=False, inference_only=True) + infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + + dynamic_batch_manager = DynamicBatchManager( + tp_engine=infer_engine, + max_total_token_num=42, + batch_max_tokens=42, + eos_id=0, + log_stats=False, + log_stats_interval=10, + waiting_req_list=waiting_list, + ) + before_add = len(dynamic_batch_manager.req_queue) + + # test add req function + dynamic_batch_manager.add_req(req4.prompt_ids, req4.sample_params, req4.request_id) + assert len(dynamic_batch_manager.req_queue.waiting_req_list) == before_add + 1 + + # test abort function + dynamic_batch_manager.abort(req4.request_id) + assert dynamic_batch_manager.req_queue.waiting_req_list[-1].aborted == True + + # test filter batch function, loop_for_fwd, _step, _init_batch and _prefill/_decode batch are tested + batch = dynamic_batch_manager.req_queue.generate_new_batch() + assert len(batch) == 2 + + dynamic_batch_manager._init_batch(batch) + assert dynamic_batch_manager.engine.cache[batch.batch_id] is not None + + batch.reqs[0].has_generate_finished = True + # filter one finished + batch.filter_finished() + dynamic_batch_manager._filter_batch(batch) + assert len(dynamic_batch_manager.engine.cache) == 1 + + # test merge batch + new_batch = dynamic_batch_manager.req_queue.generate_new_batch(batch) + assert len(new_batch) == 1 + dynamic_batch_manager._init_batch(new_batch) + dynamic_batch_manager._merge_batch(batch, new_batch) + + assert len(dynamic_batch_manager.engine.cache[batch.batch_id]) == 2 + + +def check_dynamic_batching_manager(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching_manager(): + spawn(check_dynamic_batching_manager, 1) + + +if __name__ == "__main__": + test_dynamic_batching_manager() diff --git a/tests/test_infer/test_dynamic_batching/test_forward.py b/tests/test_infer/test_dynamic_batching/test_forward.py new file mode 100644 index 000000000..63df491e5 --- /dev/null +++ b/tests/test_infer/test_dynamic_batching/test_forward.py @@ -0,0 +1,70 @@ +import pytest +import torch +from packaging import version +from transformers import LlamaForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig + +import colossalai +from dataclasses import dataclass +from colossalai.inference.dynamic_batching.io_struct import Req +from colossalai.inference.dynamic_batching.sampling_params import SamplingParams +from colossalai.inference.manager import start_dynamic_batching +from colossalai.inference.tensor_parallel import TPInferEngine +from colossalai.shardformer import ShardConfig +from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn + +TP_SIZE = 1 +MAX_BATCH_SIZE = 2 +MAX_INPUT_LEN = 5 
+MAX_OUTPUT_LEN = 16 +CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5") + +@dataclass +class args: + max_total_token_num: int + batch_max_tokens: int + eos_id: int + disable_log_stats: bool + log_stats_interval: int + + +def run(): + arg = args(max_total_token_num=42, batch_max_tokens=42, eos_id=0, disable_log_stats=False, log_stats_interval=10) + sampling_params = SamplingParams() + + req1 = Req(0, [0, 0, 10, 6, 8], sampling_params) + req2 = Req(1, [10, 10, 10, 10, 10], sampling_params) + req3 = Req(2, [0, 0, 10, 10, 10], sampling_params) + req4 = Req(3, [0, 0, 10, 10, 10], sampling_params) + + waiting_list = [] + waiting_list.append(req1) + waiting_list.append(req2) + waiting_list.append(req3) + waiting_list.append(req4) + + llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024) + model = LlamaForCausalLM(llama_config) + model = model.half() + + shard_config = ShardConfig(enable_tensor_parallelism=True if TP_SIZE > 1 else False, inference_only=True) + + infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) + start_dynamic_batching(arg, tp_engine=infer_engine, waiting_req_list=waiting_list) + + +def check_dynamic_forward(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + run() + + +@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.dist +@rerun_if_address_is_in_use() +@clear_cache_before_run() +def test_dynamic_batching(): + spawn(check_dynamic_forward, TP_SIZE) + + +if __name__ == "__main__": + test_dynamic_batching() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index 13bdf0399..b424525a3 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -38,7 +38,6 @@ def run_llama_test(test_config): enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True ) infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN) - init_to_get_rotary(model.model, base=10000) generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False) input_tokens = {