# Mirror of https://github.com/hpcaitech/ColossalAI (69 lines, 1.8 KiB, Python)
import pytest
|
|
|
|
from colossalai.inference.core.async_engine import Tracer
|
|
from colossalai.inference.struct import Sequence
|
|
|
|
|
|
class SampleEvent:
    """Synchronous stand-in for the tracer's ``new_requests_event``.

    Only ``set`` and ``clear`` are provided. The current state is exposed as
    the plain boolean attribute ``flag``, so tests can inspect it directly
    without awaiting anything.
    """

    def __init__(self):
        # Start in the cleared state.
        self.clear()

    def set(self):
        """Raise the flag."""
        self.flag = True

    def clear(self):
        """Lower the flag."""
        self.flag = False
|
|
|
|
|
|
def test_request_tracer():
    """Exercise the ``Tracer`` request lifecycle end to end.

    The test covers, in order:

    * registering new requests and draining them with ``get_new_requests``;
    * rejecting a duplicate ``request_id``;
    * aborting a request that is already being tracked;
    * aborting a request before it was ever drained;
    * finishing a tracked request via ``process_finished_request``.

    ``SampleEvent`` replaces the tracer's event so that the set/clear
    signalling can be checked synchronously.
    """
    tracker = Tracer()
    tracker.new_requests_event = SampleEvent()

    # A single new request sets the event; draining it clears the event.
    stream_1 = tracker.add_request(1)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 1
    assert not stream_1.finished

    # Several pending requests are drained together, in submission order.
    stream_2 = tracker.add_request(2)
    stream_3 = tracker.add_request(3)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 2
    assert new[0]["request_id"] == 2
    assert new[1]["request_id"] == 3
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique; a rejected duplicate must not signal the event.
    with pytest.raises(KeyError):
        tracker.add_request(1)
    assert not tracker.new_requests_event.flag

    # Aborting an already-tracked request yields no new work and finishes its stream.
    tracker.abort_request(1)
    new = tracker.get_new_requests()
    assert not new
    assert not tracker.new_requests_event.flag
    assert stream_1.finished

    # A request aborted before it is drained is dropped, but its stream still finishes.
    stream_4 = tracker.add_request(4)
    tracker.abort_request(4)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not new
    assert stream_4.finished

    # Finishing request 2 through the engine path marks only its stream finished.
    stream_5 = tracker.add_request(5)
    assert tracker.new_requests_event.flag
    # NOTE(review): the first positional argument is assumed to be request_id
    # (2, matching stream_2); the remaining values look like placeholders --
    # confirm against colossalai.inference.struct.Sequence.
    tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 5
    assert stream_2.finished
    assert not stream_3.finished
    assert not stream_5.finished
|
|
|
|
|
|
if __name__ == "__main__":
    # Running the module directly executes the test without the pytest
    # collector; it passes silently or raises AssertionError on failure.
    test_request_tracer()
|