mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
2.1 KiB
81 lines
2.1 KiB
import asyncio
|
|
from dataclasses import dataclass
|
|
|
|
import pytest
|
|
|
|
from colossalai.inference.core.async_engine import AsyncInferenceEngine
|
|
|
|
|
|
@dataclass
|
|
class MockSequence:
|
|
request_id: int
|
|
|
|
|
|
class MockEngine:
|
|
def __init__(self):
|
|
self.step_calls = 0
|
|
self.add_request_calls = 0
|
|
self.abort_request_calls = 0
|
|
self.request_id = None
|
|
|
|
async def async_step(self):
|
|
self.step_calls += 1
|
|
return ([MockSequence(request_id=self.request_id)], True) if self.request_id else ([], False)
|
|
|
|
def add_single_request(self, **kwargs):
|
|
del kwargs
|
|
self.add_request_calls += 1
|
|
|
|
def generate(self, request_id):
|
|
self.request_id = request_id
|
|
|
|
def stop_generating(self):
|
|
self.request_id = None
|
|
|
|
def add_request(self, **kwargs):
|
|
del kwargs # Unused
|
|
self.add_request_calls += 1
|
|
|
|
def abort_request(self, request_id):
|
|
del request_id # Unused
|
|
self.abort_request_calls += 1
|
|
|
|
|
|
class MockAsyncInferenceEngine(AsyncInferenceEngine):
|
|
def _init_engine(self, *args, **kwargs):
|
|
return MockEngine()
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_new_requests_event():
|
|
engine = MockAsyncInferenceEngine()
|
|
engine.start_background_loop()
|
|
await asyncio.sleep(0.01)
|
|
assert engine.engine.step_calls == 0
|
|
|
|
await engine.add_request(1, "", None)
|
|
await asyncio.sleep(0.01)
|
|
assert engine.engine.add_request_calls == 1
|
|
assert engine.engine.step_calls == 1
|
|
|
|
await engine.add_request(2, "", None)
|
|
engine.engine.generate(2)
|
|
await asyncio.sleep(0)
|
|
assert engine.engine.add_request_calls == 2
|
|
assert engine.engine.step_calls == 2
|
|
await asyncio.sleep(0)
|
|
assert engine.engine.step_calls == 3
|
|
engine.engine.stop_generating()
|
|
await asyncio.sleep(0)
|
|
assert engine.engine.step_calls == 4
|
|
await asyncio.sleep(0)
|
|
assert engine.engine.step_calls == 4
|
|
|
|
await engine.add_request(3, "", None)
|
|
await asyncio.sleep(0.01)
|
|
assert engine.engine.add_request_calls == 3
|
|
assert engine.engine.step_calls == 5
|
|
await asyncio.sleep(0.01)
|
|
assert engine.engine.add_request_calls == 3
|
|
assert engine.engine.step_calls == 5
|