# Source: ColossalAI/tests/test_infer/test_async_engine/test_request_tracer.py
# (file-viewer metadata: 69 lines, 1.8 KiB, Python)

import pytest
from colossalai.inference.core.async_engine import Tracer
from colossalai.inference.struct import Sequence
class SampleEvent:
    """Synchronous event stub whose state is exposed through ``flag``.

    The test assigns an instance to ``tracker.new_requests_event``; it
    presumably replaces the event object the Tracer normally uses, letting
    the test read ``flag`` directly instead of awaiting anything.
    """

    def __init__(self) -> None:
        # A new stub starts in the cleared state.
        self.flag: bool = False

    def set(self) -> None:
        """Put the stub into the set state."""
        self.flag = True

    def clear(self) -> None:
        """Put the stub back into the cleared state."""
        self.flag = False


def test_request_tracer():
    """Exercise Tracer's request lifecycle end to end.

    Covers the following behaviours:
    - queueing new requests and draining them with ``get_new_requests``;
    - rejecting a duplicate request id with ``KeyError``;
    - aborting a request that was already handed out;
    - aborting a request before it is drained;
    - finishing a request through ``process_finished_request``.

    ``new_requests_event`` is replaced by ``SampleEvent`` so the test can
    inspect whether the tracer set or cleared it.
    """
    tracker = Tracer()
    tracker.new_requests_event = SampleEvent()

    # A single new request sets the event; draining returns it and clears the event.
    stream_1 = tracker.add_request(1)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 1
    assert not stream_1.finished

    # Several queued requests are drained together, in insertion order.
    stream_2 = tracker.add_request(2)
    stream_3 = tracker.add_request(3)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 2
    assert new[0]["request_id"] == 2
    assert new[1]["request_id"] == 3
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique; a rejected add must not signal new work.
    with pytest.raises(KeyError):
        tracker.add_request(1)
    assert not tracker.new_requests_event.flag

    # Aborting a request that was already drained yields no new requests.
    # NOTE(review): consider also asserting stream_1.finished here once it is
    # confirmed that abort_request finishes already-registered streams.
    tracker.abort_request(1)
    new = tracker.get_new_requests()
    assert not new

    # A request aborted before it is drained is dropped and its stream finished.
    stream_4 = tracker.add_request(4)
    tracker.abort_request(4)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not new
    assert stream_4.finished

    # Finishing request 2 via the engine path must not affect the newly queued request 5.
    stream_5 = tracker.add_request(5)
    assert tracker.new_requests_event.flag
    # NOTE(review): Sequence is built positionally; the first argument is
    # assumed to be the request_id (2) - confirm against Sequence's field order.
    tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 5
    assert stream_2.finished
    assert not stream_5.finished


if __name__ == "__main__":
    # Allow running this file directly (outside pytest) as a quick smoke check.
    test_request_tracer()