mirror of https://github.com/hpcaitech/ColossalAI
* fix test bugs
* add do sample test
* del useless lines
* fix comments
* fix tests
* delete version tag
* delete version tag
* add
* del test server
* fix test
* fix
* Revert "add"
This reverts commit b9305fb024.
feat/online-serving
Jianghai
7 months ago
committed by CjhHa1
12 changed files with 98 additions and 172 deletions
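For orientation, the deleted file below wraps the inference API server in a Ray actor so a pytest session fixture can hold one GPU-pinned server process. A minimal sketch of that actor pattern, with an illustrative `Echo` actor that is not from the repo:

```python
# minimal Ray actor sketch; `Echo` is illustrative only
import ray

ray.init()


@ray.remote  # the real test adds num_gpus=1 to pin the actor to a GPU
class Echo:
    def ping(self):
        return "pong"


actor = Echo.remote()  # start the actor process
assert ray.get(actor.ping.remote()) == "pong"  # block until the remote call returns
ray.shutdown()
```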
@@ -1,79 +0,0 @@ (the whole test file is removed)

```python
# inspired by vLLM
import subprocess
import sys
import time

import pytest
import ray
import requests

MAX_WAITING_TIME = 300

pytestmark = pytest.mark.asyncio


@ray.remote(num_gpus=1)
class ServerRunner:
    def __init__(self, args):
        # launch the API server as a child process so the tests can reach it over HTTP
        self.proc = subprocess.Popen(
            ["python3", "-m", "colossalai.inference.server.api_server"] + args,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        self._wait_for_server()

    def ready(self):
        return True

    def _wait_for_server(self):
        # run health check: poll until the server answers, the process dies, or we time out
        start = time.time()
        while True:
            try:
                if requests.get("http://localhost:8000/v0/models").status_code == 200:
                    break
            except Exception as err:
                if self.proc.poll() is not None:
                    raise RuntimeError("Server exited unexpectedly.") from err

                # the sleep and timeout check belong inside the except block;
                # otherwise `err` is out of scope for the `from err` below
                time.sleep(0.5)
                if time.time() - start > MAX_WAITING_TIME:
                    raise RuntimeError("Server failed to start in time.") from err

    def __del__(self):
        if hasattr(self, "proc"):
            self.proc.terminate()


@pytest.fixture(scope="session")
def server():
    ray.init()
    server_runner = ServerRunner.remote(
        [
            "--model",
            "/home/chenjianghai/data/llama-7b-hf",
        ]
    )
    ray.get(server_runner.ready.remote())
    yield server_runner
    ray.shutdown()


async def test_completion(server):
    # the Ray actor handle has no `post` method; query the HTTP endpoint directly
    # (use a real boolean for "stream"; the string "False" is truthy)
    data = {"prompt": "How are you?", "stream": False}
    response = requests.post("http://localhost:8000/v1/completion", json=data)
    assert response is not None


async def test_chat(server):
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]
    data = {"messages": messages, "stream": False}
    response = requests.post("http://localhost:8000/v1/chat", json=data)
    assert response is not None


if __name__ == "__main__":
    pytest.main([__file__])
```
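For reference, the fixture machinery above automates roughly the following manual session. This is a sketch only: the health-check polling is replaced by a fixed sleep, and the model path is a placeholder.

```python
# manual equivalent of the deleted fixture (sketch; path and sleep are placeholders)
import subprocess
import time

import requests

proc = subprocess.Popen(
    ["python3", "-m", "colossalai.inference.server.api_server",
     "--model", "/path/to/llama-7b-hf"]
)
try:
    time.sleep(30)  # crude stand-in for the polling loop in ServerRunner
    resp = requests.post(
        "http://localhost:8000/v1/completion",
        json={"prompt": "How are you?", "stream": False},
    )
    print(resp.status_code)
finally:
    proc.terminate()
```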