Back to Repositories

Testing Queue Management and Concurrency Control in gradio-app/gradio

A comprehensive test suite for validating queueing functionality in Gradio applications, focusing on request handling, concurrency limits, and status message management. The suite ensures reliable asynchronous processing and proper queue management across different concurrent operations.

Test Coverage Overview

The test suite covers essential queueing mechanisms in Gradio:
  • Single request processing and response validation
  • Queue status monitoring and size tracking
  • Concurrent request handling with various concurrency limits
  • Multiple operation types with shared concurrency IDs

Implementation Analysis

The tests use pytest fixtures and concurrently submitted client jobs to validate queue behavior. FastAPI's TestClient polls the queue-status API endpoint, while gradio_client submits jobs the way a real client would. Fixed time.sleep() delays give running jobs time to reach a known state (queued or processing) before their statuses are checked.

Technical Details

Key technical components:
  • pytest for test framework
  • FastAPI TestClient for API testing
  • gradio_client for client-side operations
  • concurrent.futures.wait for blocking until all submitted jobs have finished
  • Custom fixtures for connection management

Best Practices Demonstrated

The test suite exemplifies robust testing practices:
  • Parametrized testing for different concurrency scenarios
  • Comprehensive status checking and validation
  • Proper test isolation and cleanup
  • Clear separation of concerns between different queue operations

gradio-app/gradio

test/test_queueing.py

            
import time
from concurrent.futures import wait

import gradio_client as grc
import pytest
from fastapi.testclient import TestClient

import gradio as gr
from gradio.route_utils import API_PREFIX


class TestQueueing:
    """Tests for Gradio's request queue: status reporting and concurrency limits.

    Each test builds a small ``gr.Blocks`` demo, launches it, submits jobs
    through ``gradio_client`` and checks the statuses / results it observes.
    Several tests rely on fixed ``time.sleep`` delays so that jobs are in a
    known state (queued vs. processing) when their status is read.
    """

    def test_single_request(self, connect):
        """A single submitted job returns the event handler's output."""
        with gr.Blocks() as demo:
            name = gr.Textbox()
            output = gr.Textbox()

            def greet(x):
                return f"Hello, {x}!"

            name.submit(greet, name, output)

        with connect(demo) as client:
            job = client.submit("x", fn_index=0)
            assert job.result() == "Hello, x!"

    def test_all_status_messages(self, connect):
        """The reported queue size grows while jobs wait and drains back to 0.

        Four 2-second jobs are submitted against a handler with
        ``concurrency_limit=2``, and ``/queue/status`` is polled until the last
        job finishes.
        """
        with gr.Blocks() as demo:
            name = gr.Textbox()
            output = gr.Textbox()

            def greet(x):
                time.sleep(2)
                return f"Hello, {x}!"

            name.submit(greet, name, output, concurrency_limit=2)

        # Launched directly (not via `connect`) because the ASGI app is
        # needed to poll the queue-status endpoint.
        app, local_url, _ = demo.launch(prevent_thread_lock=True)
        try:
            test_client = TestClient(app)
            client = grc.Client(local_url)

            client.submit("a", fn_index=0)
            job2 = client.submit("b", fn_index=0)
            client.submit("c", fn_index=0)
            job4 = client.submit("d", fn_index=0)

            # Sequence of *distinct consecutive* queue sizes observed.
            sizes = []

            def record_queue_size():
                queue_status = test_client.get(f"{API_PREFIX}/queue/status").json()
                queue_size = queue_status["queue_size"]
                # Guard against an empty list so the first sample is always kept.
                if not sizes or queue_size != sizes[-1]:
                    sizes.append(queue_size)

            while job4.status().code.value != "FINISHED":
                record_queue_size()
                time.sleep(0.01)

            # Give the server a moment to drop the last job from the queue
            # before taking the final sample.
            time.sleep(0.1)
            record_queue_size()

            # Can be 2 - 4, depending on if the workers have picked up jobs
            # before the queue status is checked.
            assert max(sizes) in [2, 3, 4]

            assert min(sizes) == 0
            assert sizes[-1] == 0

            assert job2.result() == "Hello, b!"
            assert job4.result() == "Hello, d!"
        finally:
            # Shut the launched server down so it does not leak into later tests.
            demo.close()

    @pytest.mark.parametrize(
        "default_concurrency_limit, statuses",
        [
            ("not_set", ["IN_QUEUE", "IN_QUEUE", "PROCESSING"]),
            (None, ["PROCESSING", "PROCESSING", "PROCESSING"]),
            (1, ["IN_QUEUE", "IN_QUEUE", "PROCESSING"]),
            (2, ["IN_QUEUE", "PROCESSING", "PROCESSING"]),
        ],
    )
    def test_default_concurrency_limits(self, default_concurrency_limit, statuses):
        """``demo.queue(default_concurrency_limit=...)`` caps parallel jobs.

        Three 4-second jobs are submitted, and the sorted statuses are compared
        with ``statuses`` 2 seconds later (mid-way through the first jobs).
        """
        with gr.Blocks() as demo:
            a = gr.Number()
            b = gr.Number()
            output = gr.Number()

            add_btn = gr.Button("Add")

            @add_btn.click(inputs=[a, b], outputs=output)
            def add(x, y):
                time.sleep(4)
                return x + y

        demo.queue(default_concurrency_limit=default_concurrency_limit)
        _, local_url, _ = demo.launch(
            prevent_thread_lock=True,
        )
        try:
            client = grc.Client(local_url)

            add_job_1 = client.submit(1, 1, fn_index=0)
            add_job_2 = client.submit(1, 1, fn_index=0)
            add_job_3 = client.submit(1, 1, fn_index=0)

            # Sample statuses while the first batch is still running.
            time.sleep(2)

            add_job_statuses = [
                add_job_1.status(),
                add_job_2.status(),
                add_job_3.status(),
            ]
            assert sorted([s.code.value for s in add_job_statuses]) == statuses
        finally:
            # Shut the launched server down so parametrized runs don't pile up.
            demo.close()

    def test_concurrency_limits(self, connect):
        """Per-event ``concurrency_limit`` and shared ``concurrency_id`` work.

        - ``add``: limit 2, so one of three jobs stays queued.
        - ``sub``: limit ``None``, so all three jobs run at once.
        - ``mul``/``div``: share ``concurrency_id="muldiv"`` with limit 2, so
          one of the three combined jobs stays queued.
        """
        with gr.Blocks() as demo:
            a = gr.Number()
            b = gr.Number()
            output = gr.Number()

            add_btn = gr.Button("Add")

            @add_btn.click(inputs=[a, b], outputs=output, concurrency_limit=2)
            def add(x, y):
                time.sleep(2)
                return x + y

            sub_btn = gr.Button("Subtract")

            @sub_btn.click(inputs=[a, b], outputs=output, concurrency_limit=None)
            def sub(x, y):
                time.sleep(2)
                return x - y

            mul_btn = gr.Button("Multiply")

            @mul_btn.click(
                inputs=[a, b],
                outputs=output,
                concurrency_limit=2,
                concurrency_id="muldiv",
            )
            def mul(x, y):
                time.sleep(2)
                return x * y

            div_btn = gr.Button("Divide")

            @div_btn.click(
                inputs=[a, b],
                outputs=output,
                concurrency_limit=2,
                concurrency_id="muldiv",
            )
            def div(x, y):
                time.sleep(2)
                return x / y

        with connect(demo) as client:
            add_job_1 = client.submit(1, 1, fn_index=0)
            add_job_2 = client.submit(1, 1, fn_index=0)
            add_job_3 = client.submit(1, 1, fn_index=0)
            sub_job_1 = client.submit(1, 1, fn_index=1)
            sub_job_2 = client.submit(1, 1, fn_index=1)
            sub_job_3 = client.submit(1, 1, fn_index=1)
            mul_job_1 = client.submit(1, 1, fn_index=2)
            div_job_1 = client.submit(1, 1, fn_index=3)
            mul_job_2 = client.submit(1, 1, fn_index=2)

            # Sample statuses while the first jobs of each group are running.
            time.sleep(2)

            add_job_statuses = [
                add_job_1.status(),
                add_job_2.status(),
                add_job_3.status(),
            ]
            assert sorted([s.code.value for s in add_job_statuses]) == [
                "IN_QUEUE",
                "PROCESSING",
                "PROCESSING",
            ]

            sub_job_statuses = [
                sub_job_1.status(),
                sub_job_2.status(),
                sub_job_3.status(),
            ]
            assert [s.code.value for s in sub_job_statuses] == [
                "PROCESSING",
                "PROCESSING",
                "PROCESSING",
            ]

            muldiv_job_statuses = [
                mul_job_1.status(),
                div_job_1.status(),
                mul_job_2.status(),
            ]
            assert sorted([s.code.value for s in muldiv_job_statuses]) == [
                "IN_QUEUE",
                "PROCESSING",
                "PROCESSING",
            ]

            # Block until every submitted job completes (each listed once).
            wait(
                [
                    add_job_1,
                    add_job_2,
                    add_job_3,
                    sub_job_1,
                    sub_job_2,
                    sub_job_3,
                    mul_job_1,
                    div_job_1,
                    mul_job_2,
                ]
            )