Testing Flagging System Implementation in Gradio Framework
This test suite validates the flagging functionality in the Gradio framework, focusing on data logging and user feedback mechanisms. The tests cover various flagging modes, file handling, and configuration options for the flagging system.
Test Coverage Overview
Implementation Analysis
Technical Details
Best Practices Demonstrated
gradio-app/gradio
test/test_flagging.py
import os
import pathlib
import tempfile
from unittest.mock import MagicMock
import pytest
import gradio as gr
from gradio import flagging
# Set once at import time so it applies to every test in this module.
# NOTE(review): presumably this turns off Gradio's analytics/telemetry while
# the tests launch real Interface servers — confirm against gradio's handling
# of GRADIO_ANALYTICS_ENABLED.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
class TestDefaultFlagging:
    """Tests for the flagging callback an Interface uses when none is passed.

    Every test launches a real Interface against a throwaway flagging
    directory and closes the server in a ``finally`` clause, so a failing
    assertion cannot leave a server running for the rest of the session.
    """

    def test_default_flagging_callback(self):
        """Each ``flag()`` call returns the number of data rows written so far."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # 2 rows written including header
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 2  # 3 rows written including header
            finally:
                io.close()

    def test_files_saved_as_file_paths(self):
        """A flagged image is recorded in ``dataset1.csv`` as a path to the file."""
        image = {"path": str(pathlib.Path(__file__).parent / "test_files" / "bus.png")}
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "image",
                "image",
                flagging_dir=tmpdirname,
                flagging_mode="auto",
            )
            io.launch(prevent_thread_lock=True)
            try:
                io.flagging_callback.flag([image, image])
            finally:
                # Close exactly once, before the log file is inspected.
                io.close()
            with open(os.path.join(tmpdirname, "dataset1.csv")) as f:
                # Line 0 is the header; take the first cell of the first data row.
                flagged_data = f.readlines()[1].split(",")[0]
            assert flagged_data.endswith("bus.png")

    def test_flagging_does_not_create_unnecessary_directories(self):
        """Flagging plain text leaves only ``dataset1.csv`` in the flagging dir."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            try:
                io.flagging_callback.flag(["test", "test"])
                assert os.listdir(tmpdirname) == ["dataset1.csv"]
            finally:
                io.close()
class TestSimpleFlagging:
    """Tests for ``flagging.SimpleCSVLogger`` used as the flagging callback."""

    def test_simple_csv_flagging_callback(self):
        """``flag()`` returns a zero-based row count because no header is written."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=tmpdirname,
                flagging_callback=flagging.SimpleCSVLogger(),
            )
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 0  # no header in SimpleCSVLogger
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1  # no header in SimpleCSVLogger
            finally:
                # Shut the server down even if an assertion above fails.
                io.close()
class TestDisableFlagging:
    """Tests for Interfaces created with ``flagging_mode="never"``."""

    def test_flagging_no_permission_error_with_flagging_disabled(self):
        """Building and launching must not fail when ``flagging_dir`` is unwritable.

        The flagging directory is placed inside a read-only parent, so any
        attempt to create it would raise. The test passes if the Interface can
        be constructed, launched and closed without an exception.
        """
        # A managed temporary directory is removed afterwards; a bare mkdtemp()
        # would leave a read-only directory behind on every run.
        with tempfile.TemporaryDirectory() as tmpdirname:
            os.chmod(tmpdirname, 0o444)  # Make directory read-only
            try:
                nonwritable_path = os.path.join(tmpdirname, "flagging_dir")
                io = gr.Interface(
                    lambda x: x,
                    "text",
                    "text",
                    flagging_mode="never",
                    flagging_dir=nonwritable_path,
                )
                io.launch(prevent_thread_lock=True)
                io.close()
            finally:
                # Restore owner permissions so the directory can be cleaned up.
                os.chmod(tmpdirname, 0o700)
class TestInterfaceSetsUpFlagging:
    """Tests for how ``gr.Interface`` wires up flagging at construction time."""

    @pytest.mark.parametrize(
        "flagging_mode, called",
        [
            ("manual", True),
            ("auto", True),
            ("never", False),
        ],
    )
    def test_flag_method_init_called(self, flagging_mode, called):
        """``FlagMethod`` is instantiated for "manual"/"auto" but not for "never"."""
        original_init = flagging.FlagMethod.__init__
        mock_init = MagicMock(return_value=None)
        flagging.FlagMethod.__init__ = mock_init
        try:
            gr.Interface(lambda x: x, "text", "text", flagging_mode=flagging_mode)
            assert mock_init.called == called
        finally:
            # Undo the patch so the mocked constructor does not leak into
            # other tests that run later in the same process.
            flagging.FlagMethod.__init__ = original_init

    @pytest.mark.parametrize(
        "options, processed_options",
        [
            (None, [("Flag", None)]),
            (["yes", "no"], [("Flag as yes", "yes"), ("Flag as no", "no")]),
            ([("abc", "de"), ("123", "45")], [("abc", "de"), ("123", "45")]),
        ],
    )
    def test_flagging_options_processed_correctly(self, options, processed_options):
        """``flagging_options`` is normalised to a list of ``(label, value)`` pairs."""
        io = gr.Interface(lambda x: x, "text", "text", flagging_options=options)
        assert io.flagging_options == processed_options