import pytest

from colossalai.inference.core.async_engine import Tracer
from colossalai.inference.struct import Sequence


class SampleEvent:
    def __init__(self):
        self.flag = False

    def set(self):
        self.flag = True

    def clear(self):
        self.flag = False


def test_request_tracer():
    tracker = Tracer()
    tracker.new_requests_event = SampleEvent()
    stream_1 = tracker.add_request(1)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 1
    assert not stream_1.finished

    stream_2 = tracker.add_request(2)
    stream_3 = tracker.add_request(3)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 2
    assert new[0]["request_id"] == 2
    assert new[1]["request_id"] == 3
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request(1)
    assert not tracker.new_requests_event.flag

    tracker.abort_request(1)
    new = tracker.get_new_requests()
    assert not new

    stream_4 = tracker.add_request(4)
    tracker.abort_request(4)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not new
    assert stream_4.finished

    stream_5 = tracker.add_request(5)
    assert tracker.new_requests_event.flag
    tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 5
    assert stream_2.finished
    assert not stream_5.finished


if __name__ == "__main__":
    test_request_tracer()