import os
import pathlib
import shutil
import tempfile
from unittest.mock import MagicMock

import pytest

import gradio as gr
from gradio import flagging
|
|
|
# Set at import time, before any Interface is constructed below.
# NOTE(review): presumably Gradio reads this variable to switch analytics
# off for the test run -- confirm against the Gradio environment-variable docs.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
|
|
|
|
class TestDefaultFlagging:
    """Tests for the flagging callback an ``Interface`` gets when none is passed."""

    def test_default_flagging_callback(self):
        """Successive ``flag`` calls return 1 and then 2."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            try:
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 1
                row_count = io.flagging_callback.flag(["test", "test"])
                assert row_count == 2
            finally:
                # Close the launched app even when an assertion above fails.
                io.close()

    def test_files_saved_as_file_paths(self):
        """A flagged image is written to ``dataset1.csv`` as a path ending in ``bus.png``."""
        image = {"path": str(pathlib.Path(__file__).parent / "test_files" / "bus.png")}
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(
                lambda x: x,
                "image",
                "image",
                flagging_dir=tmpdirname,
                flagging_mode="auto",
            )
            io.launch(prevent_thread_lock=True)
            try:
                io.flagging_callback.flag([image, image])
            finally:
                # Closed exactly once, before the CSV is inspected.
                io.close()
            with open(os.path.join(tmpdirname, "dataset1.csv")) as f:
                # Second line, first comma-separated field (the first line is
                # presumably the CSV header -- verify against the logger).
                flagged_data = f.readlines()[1].split(",")[0]
            assert flagged_data.endswith("bus.png")

    def test_flagging_does_not_create_unnecessary_directories(self):
        """After one ``flag`` call the flagging dir contains only ``dataset1.csv``."""
        with tempfile.TemporaryDirectory() as tmpdirname:
            io = gr.Interface(lambda x: x, "text", "text", flagging_dir=tmpdirname)
            io.launch(prevent_thread_lock=True)
            try:
                io.flagging_callback.flag(["test", "test"])
                assert os.listdir(tmpdirname) == ["dataset1.csv"]
            finally:
                # The launched app was previously never closed in this test.
                io.close()
|
|
|
|
|
class TestSimpleFlagging:
    """Checks an ``Interface`` configured with ``flagging.SimpleCSVLogger``."""

    def test_simple_csv_flagging_callback(self):
        """Two successive ``flag`` calls return 0 and then 1."""
        with tempfile.TemporaryDirectory() as flag_dir:
            demo = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_dir=flag_dir,
                flagging_callback=flagging.SimpleCSVLogger(),
            )
            demo.launch(prevent_thread_lock=True)
            # A fresh sample list is passed on each call, as in the original.
            for expected_count in (0, 1):
                returned = demo.flagging_callback.flag(["test", "test"])
                assert returned == expected_count
            demo.close()
|
|
|
|
|
class TestDisableFlagging:
    """Tests for an ``Interface`` created with ``flagging_mode="never"``."""

    def test_flagging_no_permission_error_with_flagging_disabled(self):
        """Building and launching with a flagging dir under a read-only parent must not raise."""
        tmpdirname = tempfile.mkdtemp()
        try:
            # Mode 0o444 removes write permission from the parent directory.
            os.chmod(tmpdirname, 0o444)
            nonwritable_path = os.path.join(tmpdirname, "flagging_dir")
            io = gr.Interface(
                lambda x: x,
                "text",
                "text",
                flagging_mode="never",
                flagging_dir=nonwritable_path,
            )
            io.launch(prevent_thread_lock=True)
            io.close()
        finally:
            # mkdtemp() leaves deletion to the caller; restore permissions
            # first so the directory can actually be removed.
            os.chmod(tmpdirname, 0o700)
            shutil.rmtree(tmpdirname, ignore_errors=True)
|
|
|
|
|
class TestInterfaceSetsUpFlagging:
    """Tests for how ``Interface`` construction wires up flagging."""

    @pytest.mark.parametrize(
        "flagging_mode, called",
        [
            ("manual", True),
            ("auto", True),
            ("never", False),
        ],
    )
    def test_flag_method_init_called(self, flagging_mode, called):
        """``FlagMethod.__init__`` is invoked for "manual"/"auto" but not for "never"."""
        original_init = flagging.FlagMethod.__init__
        mock_init = MagicMock(return_value=None)
        flagging.FlagMethod.__init__ = mock_init
        try:
            gr.Interface(lambda x: x, "text", "text", flagging_mode=flagging_mode)
        finally:
            # Put the real __init__ back so the mock does not leak into
            # other tests that run in the same process.
            flagging.FlagMethod.__init__ = original_init
        assert mock_init.called == called

    @pytest.mark.parametrize(
        "options, processed_options",
        [
            (None, [("Flag", None)]),
            (["yes", "no"], [("Flag as yes", "yes"), ("Flag as no", "no")]),
            ([("abc", "de"), ("123", "45")], [("abc", "de"), ("123", "45")]),
        ],
    )
    def test_flagging_options_processed_correctly(self, options, processed_options):
        """``flagging_options`` is normalised to a list of ``(label, value)`` tuples."""
        io = gr.Interface(lambda x: x, "text", "text", flagging_options=options)
        assert io.flagging_options == processed_options
|
|