Spaces:
Runtime error
Runtime error
import pytest | |
from unittest.mock import Mock, call | |
from datasets import Dataset | |
from substra_template.substra_runner import SubstraRunner | |
class TestSubstraRunner:
    """Unit tests for ``SubstraRunner``."""

    # The two helpers below must be registered as pytest fixtures: tests request
    # them by parameter name, and without ``@pytest.fixture`` pytest reports
    # "fixture not found" instead of injecting the mock.
    @pytest.fixture
    def mock_substra_client_class(self, monkeypatch):
        """Replace ``Client`` in ``substra_template.substra_runner`` with a Mock.

        ``monkeypatch`` restores the original attribute when the test finishes.

        Returns:
            The ``Mock`` standing in for the ``Client`` class.
        """
        mock_substra_client_class = Mock()
        monkeypatch.setattr(
            "substra_template.substra_runner.Client", mock_substra_client_class
        )
        return mock_substra_client_class

    @pytest.fixture
    def mock_load_dataset(self, monkeypatch):
        """Replace ``load_dataset`` in ``substra_template.substra_runner`` with a Mock.

        ``monkeypatch`` restores the original attribute when the test finishes.

        Returns:
            The ``Mock`` standing in for ``load_dataset``.
        """
        mock_load_dataset = Mock()
        monkeypatch.setattr(
            "substra_template.substra_runner.load_dataset", mock_load_dataset
        )
        return mock_load_dataset

    def test_set_up_clients(self, mock_substra_client_class):
        """``set_up_clients()`` constructs at least one ``Client``."""
        runner = SubstraRunner()
        runner.set_up_clients()
        mock_substra_client_class.assert_called()

    def test_prepare_data(self, mock_load_dataset):
        """``prepare_data()`` loads MNIST train/test and yields ``num_clients - 1`` datasets."""
        runner = SubstraRunner()
        runner.prepare_data()
        mock_load_dataset.assert_has_calls(
            calls=[
                call("mnist", split="train"),
                call("mnist", split="test"),
            ],
            any_order=True,
        )
        # NOTE(review): presumably one client does not receive a data partition
        # (hence the ``- 1``); confirm against SubstraRunner.prepare_data.
        assert len(runner.datasets) == runner.num_clients - 1

    def test_register_data(self, mock_load_dataset):
        """Smoke test: ``register_data()`` runs on ``num_clients - 1`` empty datasets."""
        # ``mock_load_dataset`` is requested so that any ``load_dataset`` call
        # hits the mock rather than the real loader; NOTE(review): confirm
        # whether register_data calls it, and whether it needs set_up_clients().
        runner = SubstraRunner()
        runner.datasets = [
            Dataset.from_dict({}) for _ in range(runner.num_clients - 1)
        ]
        runner.register_data()

    def test_register_metric(self):
        """Smoke test: ``register_metric()`` runs after ``set_up_clients()``."""
        # NOTE(review): ``Client`` is not mocked here, so set_up_clients()
        # constructs real substra clients; confirm this is intended.
        runner = SubstraRunner()
        runner.set_up_clients()
        runner.register_metric()

    # Placeholders: skip explicitly so the report does not show them as passing.
    @pytest.mark.skip(reason="TODO: test not implemented yet")
    def test_set_aggregation(self):
        pass

    @pytest.mark.skip(reason="TODO: test not implemented yet")
    def test_set_testing(self):
        pass