# ColossalAI/tests/test_infer/test_server.py
# inspired by vLLM
import os
import subprocess
import sys
import time

import pytest
import ray
import requests
# Maximum number of seconds to wait for the API server to become healthy
# before giving up and failing the test session.
MAX_WAITING_TIME = 300

# Run every `async def` test in this module under pytest-asyncio.
pytestmark = pytest.mark.asyncio
@ray.remote(num_gpus=1)
class ServerRunner:
    """Ray actor that launches the ColossalAI inference API server in a
    subprocess and blocks until its HTTP health endpoint answers.
    """

    def __init__(self, args):
        # Forward the server's output to our own stdout/stderr so startup
        # failures are visible in the pytest log.
        self.proc = subprocess.Popen(
            ["python3", "-m", "colossalai.inference.server.api_server"] + args,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        # Trivial probe: a successful remote call proves the actor (and thus
        # the server startup in __init__) completed.
        return True

    def _wait_for_server(self):
        """Poll the health endpoint until the server responds with HTTP 200.

        Raises:
            RuntimeError: if the subprocess exits, or the server does not
                become healthy within ``MAX_WAITING_TIME`` seconds.
        """
        start = time.time()
        while True:
            try:
                if requests.get("http://localhost:8000/v0/models").status_code == 200:
                    break
            except Exception as err:
                # A dead subprocess can never become healthy; fail fast.
                if self.proc.poll() is not None:
                    raise RuntimeError("Server exited unexpectedly.") from err
            # Sleep and enforce the deadline on *every* iteration. The
            # original only did this inside the except-branch, so a server
            # that kept answering with a non-200 status would busy-spin
            # forever with no timeout.
            time.sleep(0.5)
            if time.time() - start > MAX_WAITING_TIME:
                raise RuntimeError("Server failed to start in time.")

    def __del__(self):
        # __init__ may have raised before self.proc was assigned.
        if hasattr(self, "proc"):
            self.proc.terminate()
@pytest.fixture(scope="session")
def server():
    """Session-scoped fixture: start one inference server for all tests.

    The model path may be overridden via the ``COLOSSALAI_TEST_MODEL_PATH``
    environment variable; it defaults to the original hard-coded path so
    existing setups keep working.

    Yields:
        The ``ServerRunner`` actor handle backing the running server.
    """
    model_path = os.environ.get(
        "COLOSSALAI_TEST_MODEL_PATH", "/home/chenjianghai/data/llama-7b-hf"
    )
    ray.init()
    try:
        server_runner = ServerRunner.remote(["--model", model_path])
        # Block until the actor reports the server is healthy.
        ray.get(server_runner.ready.remote())
        yield server_runner
    finally:
        # Tear Ray down even if startup or one of the tests failed, so a
        # leaked cluster cannot poison later test sessions.
        ray.shutdown()
async def test_completion(server):
    """Smoke-test the completion endpoint with one non-streaming prompt."""
    data = {"prompt": "How are you?", "stream": "False"}
    # The ServerRunner actor defines no `post` method, so the original
    # `await server.post(...)` could never succeed on a Ray actor handle;
    # talk to the HTTP endpoint directly instead.
    response = requests.post("http://localhost:8000/v1/completion", json=data)
    # Assert a 200 rather than `is not None` — requests.post never returns
    # None, so the original assertion was vacuous.
    assert response.status_code == 200
async def test_chat(server):
    """Smoke-test the chat endpoint with a minimal two-message dialogue."""
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]
    data = {"messages": messages, "stream": "False"}
    # The ServerRunner actor defines no `post` method, so the original
    # `await server.post(...)` could never succeed; additionally the payload
    # was passed positionally rather than as the JSON body. POST directly
    # with json=data instead.
    response = requests.post("http://localhost:8000/v1/chat", json=data)
    # Assert a 200 rather than `is not None` — requests.post never returns
    # None, so the original assertion was vacuous.
    assert response.status_code == 200
# Allow running this module directly (`python test_server.py`) in addition
# to invoking it through the normal pytest collection.
if __name__ == "__main__":
    pytest.main([__file__])