
Testing Asynchronous Request Tracing Implementation in ColossalAI

This test suite validates the request tracer (the Tracer class) in ColossalAI’s asynchronous inference engine. It verifies request registration and tracking, event-flag handling, and the processing of finished sequences across multiple concurrent inference requests.

Test Coverage Overview

The coverage centers on the core functionality of the Tracer class, which manages the lifecycle of inference requests.

  • Request addition and tracking
  • Event flag management
  • Unique request ID validation
  • Request abortion handling
  • Finished request processing
  • Multiple concurrent request management

Implementation Analysis

The testing approach drives the Tracer through a systematic sequence of operations inside a single pytest test function (no fixtures are needed), verifying its state management and event handling at each step.

The test swaps the tracer’s new-requests event for a SampleEvent test double so the asynchronous signal can be inspected synchronously, and it validates the state transitions of the request streams returned by add_request.
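
The SampleEvent double keeps the set()/clear() surface the tracer calls while exposing the state as a plain flag attribute, so assertions need no event loop. The real event on the tracer is presumably an awaitable event such as asyncio.Event; the comparison below is an illustration for context only and is not taken from the repository.

import asyncio


class SampleEvent:
    """Synchronous stand-in for an awaitable event: same set()/clear() calls,
    but the state is a plain attribute a test can assert on directly."""

    def __init__(self):
        self.flag = False

    def set(self):
        self.flag = True

    def clear(self):
        self.flag = False


real_event = asyncio.Event()  # awaiting it requires a running event loop
fake_event = SampleEvent()    # no loop needed; state is just a bool

fake_event.set()
assert fake_event.flag        # comparable check on the real event: real_event.is_set()
fake_event.clear()
assert not fake_event.flag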

Technical Details

  • Testing Framework: pytest
  • Key Components: Tracer, SampleEvent, Sequence
  • Test Patterns: State verification, exception handling
  • Validation Points: Request states, event flags, request queues

Best Practices Demonstrated

The test demonstrates robust validation practices for asynchronous systems.

  • Comprehensive state verification
  • Error case handling with pytest.raises (illustrated in the sketch after this list)
  • Clear test flow organization
  • Isolated test scenarios
  • Thorough edge case coverage
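
The duplicate-ID check in the test is a good example of the pytest.raises pattern noted above. The snippet below is a small, self-contained sketch of that pattern using a plain dict as a stand-in for the tracer’s internal registry; the register helper is hypothetical, not part of ColossalAI, and simply mirrors the KeyError the test expects from add_request.

import pytest


def register(requests: dict, request_id: int) -> None:
    # Hypothetical helper: reject duplicate IDs, as the tracer's add_request does.
    if request_id in requests:
        raise KeyError(f"request {request_id} already exists")
    requests[request_id] = object()


def test_duplicate_id_rejected():
    requests = {}
    register(requests, 1)
    with pytest.raises(KeyError):
        register(requests, 1)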

hpcaitech/colossalai

tests/test_infer/test_async_engine/test_request_tracer.py

import pytest

from colossalai.inference.core.async_engine import Tracer
from colossalai.inference.struct import Sequence


class SampleEvent:
    def __init__(self):
        self.flag = False

    def set(self):
        self.flag = True

    def clear(self):
        self.flag = False


def test_request_tracer():
    tracker = Tracer()
    tracker.new_requests_event = SampleEvent()
    stream_1 = tracker.add_request(1)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 1
    assert not stream_1.finished

    stream_2 = tracker.add_request(2)
    stream_3 = tracker.add_request(3)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 2
    assert new[0]["request_id"] == 2
    assert new[1]["request_id"] == 3
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request(1)
    assert not tracker.new_requests_event.flag

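    # Abort a request that was already handed to the engine; nothing new
    # should appear in the next get_new_requests() call.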
    tracker.abort_request(1)
    new = tracker.get_new_requests()
    assert not new

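    # Abort a request before it is ever fetched: it is not returned by
    # get_new_requests() and its stream is marked finished.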
    stream_4 = tracker.add_request(4)
    tracker.abort_request(4)
    assert tracker.new_requests_event.flag
    new = tracker.get_new_requests()
    assert not new
    assert stream_4.finished

    stream_5 = tracker.add_request(5)
    assert tracker.new_requests_event.flag
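    # Finishing sequence 2 marks stream_2 as done while stream_5 stays pending.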
    tracker.process_finished_request(Sequence(2, "output", [], 4, [], 0, 0))
    new = tracker.get_new_requests()
    assert not tracker.new_requests_event.flag
    assert len(new) == 1
    assert new[0]["request_id"] == 5
    assert stream_2.finished
    assert not stream_5.finished


if __name__ == "__main__":
    test_request_tracer()