[Inference] Add async and sync API server using FastAPI (#5396)

* add api server

* fix

* add

* add completion service and fix bug

* add generation config

* revise shardformer

* fix bugs

* add docstrings and fix some bugs

* fix bugs and add choices for prompt template
feat/online-serving
Jianghai 2024-03-01 14:47:36 +08:00 committed by CjhHa1
parent d482922035
commit 69cd7e069d
13 changed files with 789 additions and 25 deletions

View File

@@ -62,6 +62,9 @@ class BatchBucket:
def current_batch_size(self):
return self._current_batch_size
def __len__(self):
return self._current_batch_size
@property
def available_batch_size(self):
return self.max_batch_size - self._current_batch_size

View File

@@ -1,10 +1,10 @@
"""
Our config contains various options for inference optimization, it is a unified API that wraps all the configurations for inference.
"""
import dataclasses
import logging
from dataclasses import dataclass
from typing import Optional, Union
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed as dist
@@ -214,3 +214,18 @@ class InferenceConfig:
meta_config[type] = getattr(model_config, type)
return GenerationConfig.from_dict(meta_config)
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "InferenceConfig":
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
inference_config_args = {}
for attr in attrs:
if attr in config_dict:
inference_config_args[attr] = config_dict[attr]
else:
inference_config_args[attr] = getattr(cls, attr)
# Set the attributes from the parsed arguments.
inference_config = cls(**inference_config_args)
return inference_config
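A quick illustration of how `from_dict` behaves (a sketch only; `max_batch_size` is a config field per the server arguments below, while `not_a_field` is hypothetical):

# Sketch: unknown keys are ignored; missing fields fall back to the dataclass defaults.
cfg = InferenceConfig.from_dict({"max_batch_size": 8, "not_a_field": True})
assert cfg.max_batch_size == 8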

View File

@@ -0,0 +1,318 @@
import asyncio
from functools import partial
import logging
from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type
from colossalai.inference.core.engine import InferenceEngine
logger = logging.getLogger(__name__)  # module-level logger used by RequestTracker
class AsyncEngineDeadError(RuntimeError):
pass
def _raise_exception_on_finish(task: asyncio.Task, request_tracker: "RequestTracker") -> None:
msg = "Task finished unexpectedly. This should never happen!"
try:
try:
task.result()
except asyncio.CancelledError:
return
except Exception as exc:
raise AsyncEngineDeadError(msg + " See stack trace above for the actual cause.") from exc
raise AsyncEngineDeadError(msg)
except Exception as exc:
request_tracker.propagate_exception(exc)
raise exc
class AsyncStream:
"""A stream of Output for a request that can be
iterated over asynchronously."""
def __init__(self, request_id: int) -> None:
self.request_id = request_id
self._queue = asyncio.Queue()
self._finished = False
def put(self, item) -> None:
if self._finished:
return
self._queue.put_nowait(item)
def finish(self) -> None:
self._queue.put_nowait(StopIteration)
self._finished = True
@property
def finished(self) -> bool:
return self._finished
def __aiter__(self):
return self
async def __anext__(self):
result = await self._queue.get()
if result is StopIteration:
raise StopAsyncIteration
elif isinstance(result, Exception):
raise result
return result
class RequestTracker:
"""Synchronous abstraction for tracking requests."""
def __init__(self) -> None:
self._request_streams: Dict[int, AsyncStream] = {}
self._finished_requests: asyncio.Queue[int] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue()
self.new_requests_event = None
def __contains__(self, item):
return item in self._request_streams
def init_event(self):
self.new_requests_event = asyncio.Event()
def propagate_exception(self, exc: Exception, request_id: Optional[int] = None) -> None:
"""
Propagate an exception to request streams (all if request_id is None).
"""
if request_id is not None:
self._request_streams[request_id].put(exc)
else:
for stream in self._request_streams.values():
stream.put(exc)
def process_finished_request(self, finished_request) -> None:
"""Process a finished request from the engine."""
request_id = finished_request.request_id
self._request_streams[request_id].put(finished_request)
self.abort_request(request_id)
def add_request(self, request_id: int, **engine_add_request_kwargs) -> AsyncStream:
"""
Add a request to be sent to the engine on the next background
loop iteration.
"""
if request_id in self._request_streams:
raise KeyError(f"Request {request_id} already exists.")
stream = AsyncStream(request_id)
self._new_requests.put_nowait((stream, {"request_id": request_id, **engine_add_request_kwargs}))
self.new_requests_event.set()
return stream
def abort_request(self, request_id: int, *, verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
logger.info(f"Aborted request {request_id}.")
self._finished_requests.put_nowait(request_id)
if request_id not in self._request_streams or self._request_streams[request_id].finished:
# The request has already finished or been aborted.
return
self._request_streams[request_id].finish()
def get_new_requests(self):
"""
Get new requests from the http server.
"""
new_requests: List[Dict] = []
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests
def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[int]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests: List[Dict] = []
finished_requests: Set[int] = set()
while not self._finished_requests.empty():
request_id = self._finished_requests.get_nowait()
finished_requests.add(request_id)
self._request_streams.pop(request_id, None)
while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if stream.request_id in finished_requests:
# The request has already been aborted.
stream.finish()
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests, finished_requests
async def wait_for_new_requests(self):
await self.new_requests_event.wait()
class _AsyncInferenceEngine(InferenceEngine):
"""
Async methods for Inference Engine.
"""
async def async_step(self) -> Tuple[List, bool]:
"""
The async version of Engine.step()
Performs one decoding iteration. It first schedules the sequences to be
executed in the next iteration, then executes the model and updates the
scheduler with the model outputs. Finally, it returns the newly finished
sequences together with a flag telling whether any requests remain in
the batch.
"""
batch = self.request_handler.schedule()
loop = asyncio.get_running_loop()
# Use run_in_executor to asynchronously run the sync method model.forward().
logits = await loop.run_in_executor(
None,
self.model,
batch,
self.k_cache,
self.v_cache,
)
if self.inference_config.pad_input:
logits = logits[:, -1, :]
self.request_handler.search_tokens(self.generation_config, logits)
# Return the finished sequences and whether any requests remain in the batch.
finished_sequences = self.request_handler.update()
return finished_sequences, self.request_handler.current_requests_in_batch() > 0
class AsyncInferenceEngine:
"""An asynchronous wrapper for LLMEngine.
This class is used to wrap the InferenceEngine class to make it asynchronous.
It uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there are
requests in the waiting queue. The generate method yields the outputs
from the InferenceEngine to the caller.
"""
_engine_class: Type[_AsyncInferenceEngine] = _AsyncInferenceEngine
def __init__(self, start_engine_loop: bool = True, **kwargs):
self.engine = self._init_engine(**kwargs)
self.background_loop = None
# reference to the unshielded loop
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
@property
def background_loop_status(self):
return self.background_loop is not None and not self.background_loop.done()
def start_background_loop(self):
if self.background_loop_status:
raise RuntimeError("Existing loop is running")
self._request_tracker.init_event()
self._background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish, request_tracker=self._request_tracker)
)
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, **kwargs):
return self._engine_class(**kwargs)
async def step(self):
"""
Run engine to process requests
Returns True if there are in-progress requests.
"""
new_requests = self._request_tracker.get_new_requests()
for new_request in new_requests:
self.engine.add_request(**new_request)
newly_finished_seqs, has_running_requests = await self.engine.async_step()
for seq in newly_finished_seqs:
self._request_tracker.process_finished_request(seq)
return has_running_requests
async def _engine_abort(self, request_ids: Iterable[int]):
self.engine.abort_request(request_ids)
async def abort(self, request_id: int):
"""
Abort a single request
"""
if not self.background_loop_status:
raise RuntimeError("Background loop is not running or launched correctly.")
return self._abort(request_id)
def _abort(self, request_id: int):
self._request_tracker.abort_request(request_id)
async def run_engine_loop(self):
processing_requests = False
while True:
if not processing_requests:
await self._request_tracker.wait_for_new_requests()
processing_requests = await self.step()
await asyncio.sleep(0)
async def add_request(
self,
request_id: int,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
) -> AsyncStream:
"""
Add a request to the background tracker (waiting queue), starting the background loop if needed.
"""
if not self.background_loop_status:
if self.start_engine_loop:
self.start_background_loop()
else:
raise RuntimeError("Background loop is not running.")
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
)
return stream
async def generate(
self,
request_id: int,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
) -> AsyncIterator[str]:
"""
Generate output from a request. It receives the request from the http server, adds it to the
waiting queue of the Async Engine, and streams the output sequence back.
"""
try:
stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
async for request_output in stream:
yield request_output
except (Exception, asyncio.CancelledError) as e:
# If there is an exception or coroutine is cancelled, abort the
# request.
self._abort(request_id)
raise e
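Putting the pieces together, a caller can drive the engine like this (a minimal sketch; `model`, `tokenizer`, and `inference_config` are assumed to be prepared as in the `__main__` block of api_server.py below):

# Sketch: streaming one request through the AsyncInferenceEngine.
import asyncio

async def main():
    engine = AsyncInferenceEngine(
        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
    )
    async for output in engine.generate(request_id=1, prompt="Hello"):
        print(output)  # results are yielded as the background loop finishes them

asyncio.run(main())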

View File

@@ -1,6 +1,6 @@
import time
from itertools import count
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -507,9 +507,9 @@ class InferenceEngine:
def generate(
self,
prompts: List[str] = None,
request_ids: Union[List[int], int] = None,
prompts: Union[List[str], str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
request_ids: List[int] = None,
return_token_ids: bool = False,
generation_config: Optional[GenerationConfig] = None,
) -> List[str]:
@@ -527,6 +527,11 @@ class InferenceEngine:
List[str]: Inference result returned by one generation.
"""
with torch.inference_mode():
if isinstance(prompts, str) and isinstance(request_ids, int):
prompts = [prompts]
request_ids = [request_ids]
if prompts is not None or prompts_token_ids is not None:
gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
self.add_request(
@@ -535,7 +540,7 @@ class InferenceEngine:
prompts_token_ids=prompts_token_ids,
**gen_config_dict,
)
output_seqs_list = []
total_tokens_list = []
@@ -580,13 +585,13 @@ class InferenceEngine:
if isinstance(prompts, (list, tuple)):
return [self.inference_config.prompt_template.format(input_text=prompt) for prompt in prompts]
elif isinstance(prompts, str):
return self.inference_config.rompt_template.format(input_text=prompts)
return self.inference_config.prompt_template.format(input_text=prompts)
else:
raise TypeError(f"Expected the input prompt to be one of list, tuple, or str, but got {type(prompts)}.")
def add_request(
self,
request_ids: List[int] = None,
request_ids: Union[List[int], int] = None,
prompts: List[str] = None,
prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
**kwargs,
@@ -601,6 +606,7 @@ class InferenceEngine:
"""
# apply the prompt template to the input prompts
if self.has_prompt_template and prompts is not None:
prompts = self.format_prompt(prompts)
@@ -614,6 +620,7 @@ class InferenceEngine:
prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=self.inference_config.pad_input)[
"input_ids"
]
if isinstance(prompts_token_ids, list):
pass
@@ -632,8 +639,6 @@ class InferenceEngine:
for i in range(prompts_num):
if request_ids:
if not isinstance(request_ids, list):
request_ids = [request_ids]
assert isinstance(
request_ids[0], int
), f"The request_id type must be int, but got {type(request_ids[0])}"
@@ -734,6 +739,9 @@ class InferenceEngine:
next_tokens = self.request_handler.search_tokens(self.generation_config, logits)
self.request_handler.append_next_tokens(next_tokens)
print("in step", logits)
self.request_handler.search_tokens(self.generation_config, logits)
finished_sequences = self.request_handler.update()
return finished_sequences
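With the reordered signature, single-prompt synchronous usage looks like this (a sketch, assuming an initialized `InferenceEngine` named `engine`):

# Sketch: request_ids now comes first and, like prompts, may be a bare scalar;
# the engine wraps both into lists internally.
outputs = engine.generate(request_ids=1, prompts="hello, who are you?")
print(outputs[0])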

View File

@@ -263,24 +263,27 @@ class RequestHandler:
), f"Sequence {req.request_id} exceeds input length limit"
self.waiting_list[req.input_len * 3 // (self.inference_config.max_input_len + 1)].append(req)
def abort_sequence(self, request_id: str):
def abort_sequence(self, request_id: int):
"""
Abort the request.
"""
seq, priority = self._find_sequence(request_id)
if seq.status == RequestStatus.WAITING:
seq.mark_aborted()
self.waiting_list[priority].remove(seq)
elif seq.status.is_running():
self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
self.running_list.remove(seq)
else:
try:
self.done_list.remove(seq)
except:
return
result = self._find_sequence(request_id)
if result is not None:
seq, priority = result
if seq.status == RequestStatus.WAITING:
seq.mark_aborted()
self.waiting_list[priority].remove(seq)
elif seq.status.is_running():
self.running_bb.pop_seq_update_batch(seq.request_id, self.cache_manager.free_block_table)
self.running_list.remove(seq)
else:
try:
self.done_list.remove(seq)
except ValueError:
return
return
def _find_sequence(self, request_id: str) -> Sequence:
def _find_sequence(self, request_id: int) -> Sequence:
"""
Find the request by request_id.
"""
@@ -324,6 +327,9 @@ class RequestHandler:
def check_unfinished_seqs(self) -> bool:
return self._has_waiting() or not self.running_list.is_empty()
def current_requests_in_batch(self) -> int:
return self.prefill_bb.current_batch_size + self.running_bb.current_batch_size
def search_tokens(self, generation_config: GenerationConfig, logits):
"""
Sample tokens for finished requests.

View File

View File

@@ -0,0 +1,200 @@
"""
Doc:
Feature:
- FastAPI based http server for Colossal-Inference
- Completion Service Supported
Usage: (for a local user)
- First, launch an API server locally: `python3 -m colossalai.inference.server.api_server --model <path of your llama2 model>`
- Second, open the page `http://127.0.0.1:8000/docs` to check the API
- For completion service, you can invoke it by using `curl -X POST http://127.0.0.1:8000/v1/completion \
-H 'Content-Type: application/json' \
-d '{"prompt":"hello, who are you? ","stream":false}'`
"""
import argparse
import json
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.inference.config import InferenceConfig
from colossalai.inference.server.completion_service import CompletionServing
from colossalai.inference.server.utils import id_generator
from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine # noqa
TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
engine = None
supported_models_dict = {"Llama_Models": ("llama2-7b",)}
prompt_template_choices = ["llama", "vicuna"]
@app.get("/v0/models")
def get_available_models() -> Response:
return JSONResponse(supported_models_dict)
@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.
A request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: optional overrides for the generation config (any matching attribute is applied).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", None)
request_id = id_generator()
generation_config = get_generation_config(request_dict)
results = engine.generate(request_id, prompt, generation_config=generation_config)
# Streaming case
def stream_results():
for request_output in results:
ret = {"text": request_output}
yield (json.dumps(ret) + "\0").encode("utf-8")
if stream:
return StreamingResponse(stream_results())
# Non-streaming case
final_output = None
for request_output in results:
if await request.is_disconnected():
# Abort the request if the client disconnects.
engine.abort(request_id)
return Response(status_code=499)
final_output = request_output
assert final_output is not None
ret = {"text": final_output}
return JSONResponse(ret)
@app.post("/v1/completion")
async def create_completion(request: Request):
request_dict = await request.json()
generation_config = get_generation_config(request_dict)
final_res = await completion_serving.create_completion(request, generation_config)
output = tokenizer.decode(final_res.output_token_id)
ret = {"request_id": final_res.request_id, "text": output}
return ret
def get_generation_config(request):
generation_config = async_engine.engine.generation_config
for arg in request:
if hasattr(generation_config, arg):
setattr(generation_config, arg, request[arg])
return generation_config
def add_engine_config(parser):
parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")
parser.add_argument(
"--max-model-len",
type=int,
default=None,
help="model context length. If unspecified, " "will be automatically derived from the model.",
)
# Parallel arguments
parser.add_argument(
"--worker-use-ray",
action="store_true",
help="use Ray for distributed serving, will be " "automatically set when using more than 1 GPU",
)
parser.add_argument("--pipeline-parallel-size", "-pp", type=int, default=1, help="number of pipeline stages")
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
# KV cache arguments
parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")
parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")
# generation arguments
parser.add_argument(
"--prompt-template",
choices=prompt_template_choices,
default=None,
help=f"Allowed choices are {', '.join(prompt_template_choices)}. Defaults to None.",
)
# Quantization settings.
parser.add_argument(
"--quantization",
"-q",
type=str,
choices=["awq", "gptq", "squeezellm", None],
default=None,
help="Method used to quantize the weights. If "
"None, we first check the `quantization_config` "
"attribute in the model config file. If that is "
"None, we assume the model weights are not "
"quantized and use `dtype` to determine the data "
"type of the weights.",
)
parser.add_argument(
"--enforce-eager",
action="store_true",
help="Always use eager-mode PyTorch. If False, "
"will use eager mode and CUDA graph in hybrid "
"for maximal performance and flexibility.",
)
return parser
def parse_args():
parser = argparse.ArgumentParser(description="Colossal-Inference API server.")
parser.add_argument("--host", type=str, default="127.0.0.1")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument(
"--root-path", type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy"
)
parser.add_argument(
"--model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.",
)
parser = add_engine_config(parser)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
inference_config = InferenceConfig.from_dict(vars(args))
model = AutoModelForCausalLM.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(args.model)
async_engine = AsyncInferenceEngine(
start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
)
engine = async_engine.engine
completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
app.root_path = args.root_path
uvicorn.run(
app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
)
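Beyond curl, the endpoints can also be exercised from Python (a sketch; `requests` is assumed to be installed, and `max_new_tokens` is shown only as an example of a generation-config override picked up by `get_generation_config`):

# Sketch: non-streaming completion request against a locally running server.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/v1/completion",
    json={"prompt": "hello, who are you? ", "max_new_tokens": 32},
)
print(resp.json())  # {"request_id": ..., "text": ...}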

View File

@@ -0,0 +1,35 @@
import asyncio
from colossalai.inference.core.async_engine import AsyncInferenceEngine
from .utils import id_generator
class CompletionServing:
def __init__(self, engine: AsyncInferenceEngine, served_model: str):
self.engine = engine
self.served_model = served_model
try:
asyncio.get_running_loop()
except RuntimeError:
# best-effort check only; the serving object may be built before the loop starts
pass
async def create_completion(self, request, generation_config):
request_dict = await request.json()
request_id = id_generator()
prompt = request_dict.pop("prompt")
# NOTE: overriding the engine-wide generation config per request is not an intuitive way to do this
self.engine.engine.generation_config = generation_config
result_generator = self.engine.generate(request_id, prompt=prompt)
final_res = None
async for res in result_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
return {"error_msg": "Client disconnected"}
final_res = res
return final_res

View File

@@ -0,0 +1,16 @@
# A process-wide singleton ID generator
class NumericIDGenerator:
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super(NumericIDGenerator, cls).__new__(cls)
cls._instance.current_id = 0
return cls._instance
def __call__(self):
self.current_id += 1
return self.current_id
id_generator = NumericIDGenerator()
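Because `__new__` always returns the same instance, every import site shares one counter (a sketch):

# Sketch: the singleton hands out strictly increasing ids process-wide.
a = NumericIDGenerator()
b = NumericIDGenerator()
assert a is b
print(id_generator(), id_generator())  # 1 2 on a fresh process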

View File

@@ -164,6 +164,7 @@ class Sequence:
return (
f"(request_id={self.request_id}, "
f"prompt={self.prompt}, "
f"output_token_id={self.output_token_id},"
f"status={self.status.name}, "
f"sample_params={self.sample_params}, "
f"input_len={self.input_len},"

View File

@@ -1,6 +1,7 @@
import os
from typing import Dict, List, Tuple
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
@@ -36,7 +37,11 @@ class ShardFormer:
"""
def __init__(self, shard_config: ShardConfig):
self.coordinator = DistCoordinator()
self.is_distributed = dist.is_initialized()
if self.is_distributed:
self.coordinator = DistCoordinator()
else:
self.coordinator = None
self.shard_config = shard_config
def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
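The new guard lets the same code path serve single-process inference (a sketch of the distinction; it assumes no process group has been initialized):

# Sketch: in a plain `python script.py` run there is no process group,
# so ShardFormer now skips creating a DistCoordinator instead of crashing.
import torch.distributed as dist
print(dist.is_initialized())  # False in a non-distributed run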

View File

@@ -0,0 +1,80 @@
import asyncio
from dataclasses import dataclass
import pytest
from colossalai.inference.core.async_engine import AsyncInferenceEngine
@dataclass
class SequenceType:
request_id: int
class MockEngine:
def __init__(self):
self.step_calls = 0
self.add_request_calls = 0
self.abort_request_calls = 0
self.request_id = None
async def async_step(self):
self.step_calls += 1
return ([SequenceType(request_id=self.request_id)], True) if self.request_id else ([], False)
def generate(self, request_id):
self.request_id = request_id
def stop_generating(self):
self.request_id = None
def add_request(self, **kwargs):
del kwargs # Unused
self.add_request_calls += 1
def abort_request(self, request_id):
del request_id # Unused
self.abort_request_calls += 1
class MockAsyncLLMEngine(AsyncInferenceEngine):
def _init_engine(self, *args, **kwargs):
return MockEngine()
@pytest.mark.asyncio
async def test_new_requests_event():
engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0
await engine.add_request(1, "", None)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 1
assert engine.engine.step_calls == 1
await engine.add_request(2, "", None)
engine.engine.generate(2)
await asyncio.sleep(0)
assert engine.engine.add_request_calls == 2
assert engine.engine.step_calls == 2
await asyncio.sleep(0)
assert engine.engine.step_calls == 3
engine.engine.stop_generating()
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await asyncio.sleep(0)
assert engine.engine.step_calls == 4
await engine.add_request(3, "", None)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == 5
if __name__ == "__main__":
asyncio.run(test_new_requests_event())

View File

@@ -0,0 +1,77 @@
import pytest
from colossalai.inference.core.async_engine import RequestTracker
from colossalai.inference.struct import Sequence
class SampleEvent:
def __init__(self):
self.flag = False
def set(self):
self.flag = True
def clear(self):
self.flag = False
def test_request_tracker():
tracker = RequestTracker()
tracker.new_requests_event = SampleEvent()
stream_1 = tracker.add_request(1)
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 1
assert new[0]["request_id"] == 1
assert not finished
assert not stream_1.finished
stream_2 = tracker.add_request(2)
stream_3 = tracker.add_request(3)
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert len(new) == 2
assert new[0]["request_id"] == 2
assert new[1]["request_id"] == 3
assert not finished
assert not stream_2.finished
assert not stream_3.finished
# request_ids must be unique
with pytest.raises(KeyError):
tracker.add_request(1)
assert not tracker.new_requests_event.flag
tracker.abort_request(1)
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert 1 in finished
assert not new
assert stream_1.finished
stream_4 = tracker.add_request(4)
tracker.abort_request(4)
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert 4 in finished
assert not new
assert stream_4.finished
stream_5 = tracker.add_request(5)
assert tracker.new_requests_event.flag
tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event.flag
assert len(finished) == 1
assert 2 in finished
assert len(new) == 1
assert new[0]["request_id"] == 5
assert stream_2.finished
assert not stream_5.finished
if __name__ == "__main__":
test_request_tracker()