Back to Repositories

Testing Flagging System Implementation in Gradio Framework

This test suite validates the flagging functionality in the Gradio framework, focusing on data logging and user feedback mechanisms. The tests cover various flagging modes, file handling, and configuration options for the flagging system.

Test Coverage Overview

The test suite provides comprehensive coverage of Gradio’s flagging system, including:
  • Default flagging behavior and row counting
  • File path handling for image inputs
  • Directory creation and management
  • Simple CSV logging implementation
  • Permission handling and disabled flagging scenarios
  • Flagging mode initialization and options processing

Implementation Analysis

The testing approach utilizes pytest’s parametrize feature for multiple test scenarios and mock objects for isolation. Tests employ temporary directories for file operations and validate both synchronous and asynchronous flagging behaviors.

Key patterns include systematic validation of file system operations, mock implementations for permission testing, and parametrized test cases for different flagging modes.

Technical Details

Testing tools and configuration:
  • pytest framework for test organization
  • unittest.mock for mocking objects
  • tempfile for temporary test directories
  • pathlib for cross-platform path handling
  • Environment variable configuration for analytics

Best Practices Demonstrated

The test suite demonstrates several testing best practices:
  • Proper resource cleanup using context managers
  • Isolation of filesystem operations
  • Comprehensive edge case handling
  • Parametrized testing for multiple scenarios
  • Clear test class organization by functionality

gradio-app/gradio

test/test_flagging.py

            
import os
import pathlib
import tempfile
from unittest.mock import MagicMock

import pytest

import gradio as gr
from gradio import flagging

os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"


class TestDefaultFlagging:
    def test_default_flagging_callback(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            row_count = io.flagging_callback.flag(["test", "test"])
            assert row_count == 1  # 2 rows written including header
            row_count = io.flagging_callback.flag(["test", "test"])
            assert row_count == 2  # 3 rows written including header
        io.close()

    def test_files_saved_as_file_paths(self):
        image = {"path": str(pathlib.Path(__file__).parent / "test_files" / "bus.png")}
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "image",
                "image",
                flagging_dir=tmpdirname,
                flagging_mode="auto",
            )
            io.launch(prevent_thread_lock=True)
            io.flagging_callback.flag([image, image])
            io.close()
            with open(os.path.join(tmpdirname, "dataset1.csv")) as f:
                flagged_data = f.readlines()[1].split(",")[0]
                assert flagged_data.endswith("bus.png")
        io.close()

    def test_flagging_does_not_create_unnecessary_directories(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            io.flagging_callback.flag(["test", "test"])
            assert os.listdir(tmpdirname) == ["dataset1.csv"]


class TestSimpleFlagging:
    def test_simple_csv_flagging_callback(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.SimpleCSVLogger(),
            )
            io.launch(prevent_thread_lock=True)
            row_count = io.flagging_callback.flag(["test", "test"])
            assert row_count == 0  # no header in SimpleCSVLogger
            row_count = io.flagging_callback.flag(["test", "test"])
            assert row_count == 1  # no header in SimpleCSVLogger
        io.close()


class TestDisableFlagging:
    def test_flagging_no_permission_error_with_flagging_disabled(self):
        tmpdirname = tempfile.mkdtemp()
        os.chmod(tmpdirname, 0o444)  # Make directory read-only
        nonwritable_path = os.path.join(tmpdirname, "flagging_dir")
        io = gr.Interface(
            lambda x: x,
            "text",
            "text",
            flagging_mode="never",
            flagging_dir=nonwritable_path,
        )
        io.launch(prevent_thread_lock=True)
        io.close()


class TestInterfaceSetsUpFlagging:
    @pytest.mark.parametrize(
        "flagging_mode, called",
        [
            ("manual", True),
            ("auto", True),
            ("never", False),
        ],
    )
    def test_flag_method_init_called(self, flagging_mode, called):
        flagging.FlagMethod.__init__ = MagicMock()
        flagging.FlagMethod.__init__.return_value = None
        gr.Interface(lambda x: x, "text", "text", flagging_mode=flagging_mode)
        assert flagging.FlagMethod.__init__.called == called

    @pytest.mark.parametrize(
        "options, processed_options",
        [
            (None, [("Flag", None)]),
            (["yes", "no"], [("Flag as yes", "yes"), ("Flag as no", "no")]),
            ([("abc", "de"), ("123", "45")], [("abc", "de"), ("123", "45")]),
        ],
    )
    def test_flagging_options_processed_correctly(self, options, processed_options):
        io = gr.Interface(lambda x: x, "text", "text", flagging_options=options)
        assert io.flagging_options == processed_options