diff --git a/test/integration_tests/conftest.py b/test/integration_tests/conftest.py
new file mode 100644
index 0000000000..6d051420ee
--- /dev/null
+++ b/test/integration_tests/conftest.py
@@ -0,0 +1,28 @@
+import shutil
+
+import pytest
+import torch
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--use-tmp-hub-dir",
+        action="store_true",
+        help=(
+            "When provided, tests will use temporary directory as Torch Hub directory. "
+            "Downloaded models will be deleted after each test."
+        ),
+    )
+
+
+@pytest.fixture(scope="class")
+def temp_hub_dir(tmp_path_factory, pytestconfig):
+    if not pytestconfig.getoption("--use-tmp-hub-dir"):
+        yield
+    else:
+        tmp_dir = tmp_path_factory.mktemp("hub", numbered=True).resolve()
+        org_dir = torch.hub.get_dir()
+        torch.hub.set_dir(tmp_dir)
+        yield
+        torch.hub.set_dir(org_dir)
+        shutil.rmtree(tmp_dir, ignore_errors=True)
diff --git a/test/integration_tests/test_models.py b/test/integration_tests/test_models.py
index 8d79c69510..d140c8190a 100644
--- a/test/integration_tests/test_models.py
+++ b/test/integration_tests/test_models.py
@@ -1,5 +1,6 @@
+import pytest  # noqa: F401
 import torch
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
 from torchtext.models import (
     ROBERTA_BASE_ENCODER,
     ROBERTA_LARGE_ENCODER,
@@ -10,7 +11,23 @@
 from ..common.assets import get_asset_path
 from ..common.torchtext_test_case import TorchtextTestCase
 
 
+BUNDLERS = {
+    "xlmr_base": XLMR_BASE_ENCODER,
+    "xlmr_large": XLMR_LARGE_ENCODER,
+    "roberta_base": ROBERTA_BASE_ENCODER,
+    "roberta_large": ROBERTA_LARGE_ENCODER,
+}
+
+@parameterized_class(
+    ("model_name",),
+    [
+        ("xlmr_base",),
+        ("xlmr_large",),
+        ("roberta_base",),
+        ("roberta_large",),
+    ],
+)
 class TestRobertaEncoders(TorchtextTestCase):
     def _roberta_encoders(self, is_jit, encoder, expected_asset_name, test_text):
         """Verify pre-trained XLM-R and Roberta models in torchtext produce
@@ -31,46 +48,20 @@ def _roberta_encoders(self, is_jit, encoder, expected_asset_name, test_text):
         expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected)
 
-    @parameterized.expand([("jit", True), ("not_jit", False)])
-    def test_xlmr_base_model(self, name, is_jit):
-        expected_asset_name = "xlmr.base.output.pt"
-        test_text = "XLMR base Model Comparison"
-        self._roberta_encoders(
-            is_jit=is_jit,
-            encoder=XLMR_BASE_ENCODER,
-            expected_asset_name=expected_asset_name,
-            test_text=test_text,
-        )
-
-    @parameterized.expand([("jit", True), ("not_jit", False)])
-    def test_xlmr_large_model(self, name, is_jit):
-        expected_asset_name = "xlmr.large.output.pt"
-        test_text = "XLMR base Model Comparison"
-        self._roberta_encoders(
-            is_jit=is_jit,
-            encoder=XLMR_LARGE_ENCODER,
-            expected_asset_name=expected_asset_name,
-            test_text=test_text,
-        )
+    @parameterized.expand(["jit", "not_jit"])
+    def test_models(self, name):
+        configuration, type = self.model_name.split("_")
 
-    @parameterized.expand([("jit", True), ("not_jit", False)])
-    def test_roberta_base_model(self, name, is_jit):
-        expected_asset_name = "roberta.base.output.pt"
-        test_text = "Roberta base Model Comparison"
-        self._roberta_encoders(
-            is_jit=is_jit,
-            encoder=ROBERTA_BASE_ENCODER,
-            expected_asset_name=expected_asset_name,
-            test_text=test_text,
-        )
+        expected_asset_name = f"{configuration}.{type}.output.pt"
+        is_jit = name == "jit"
+        if configuration == "xlmr":
+            test_text = "XLMR base Model Comparison"
+        else:
+            test_text = "Roberta base Model Comparison"
 
-    @parameterized.expand([("jit", True), ("not_jit", False)])
-    def test_robeta_large_model(self, name, is_jit):
-        expected_asset_name = "roberta.large.output.pt"
-        test_text = "Roberta base Model Comparison"
         self._roberta_encoders(
             is_jit=is_jit,
-            encoder=ROBERTA_LARGE_ENCODER,
+            encoder=BUNDLERS[configuration + "_" + type],
            expected_asset_name=expected_asset_name,
             test_text=test_text,
         )
diff --git a/test/prototype/integration_tests/test_models.py b/test/prototype/integration_tests/test_models.py
index 378a95711c..4130d67aa4 100644
--- a/test/prototype/integration_tests/test_models.py
+++ b/test/prototype/integration_tests/test_models.py
@@ -1,5 +1,6 @@
+import pytest  # noqa: F401
 import torch
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
 from test.common.assets import get_asset_path
 from test.common.parameterized_utils import nested_params
 from test.common.torchtext_test_case import TorchtextTestCase
@@ -32,7 +33,21 @@
 }
 
 
-class TestT5(TorchtextTestCase):
+@parameterized_class(
+    ("model_name",),
+    [
+        ("base_model",),
+        ("base_encoder",),
+        ("base_generation",),
+        ("small_model",),
+        ("small_encoder",),
+        ("small_generation",),
+        ("large_model",),
+        ("large_encoder",),
+        ("large_generation",),
+    ],
+)
+class TestT5Model(TorchtextTestCase):
     def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
         """Verify that pre-trained T5 models in torchtext produce
         the same output as the HuggingFace reference implementation.
@@ -55,21 +70,35 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
         expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)
 
-    @nested_params(["base", "small", "large"], ["encoder", "model", "generation"], ["jit", "not_jit"])
-    def test_t5_encoder_model(self, configuration, type, name) -> None:
+    @nested_params(["jit", "not_jit"])
+    def test_t5_model(self, name) -> None:
+        configuration, type = self.model_name.split("_")
+
         expected_asset_name = f"t5.{configuration}.{type}.output.pt"
         test_text = ["Hello world", "Attention rocks!"]
         is_jit = name == "jit"
         t5_model = BUNDLERS[configuration + "_" + type]
         self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text)
 
-    @nested_params(["base", "small", "large"], ["jit", "not_jit"])
-    def test_t5_wrapper(self, configuration, name) -> None:
+
+@parameterized_class(
+    ("configuration",),
+    [
+        ("small",),
+        ("base",),
+        ("large",),
+    ],
+)
+class TestT5Wrapper(TorchtextTestCase):
+    @parameterized.expand(["jit", "not_jit"])
+    def test_t5_wrapper(self, name) -> None:
+        configuration = self.configuration
         test_text = ["translate English to French: I want to eat pizza for dinner."]
         if configuration == "small":
             expected_text = ["Je veux manger la pizza pour le dîner."]
         else:
             expected_text = ["Je veux manger de la pizza pour le dîner."]
+
         beam_size = 3
         max_seq_len = 512
         model = T5Wrapper(configuration=configuration)
@@ -79,6 +108,8 @@ def test_t5_wrapper(self, configuration, name) -> None:
         output_text = model(test_text, beam_size, max_seq_len)
         self.assertEqual(output_text, expected_text)
 
+
+class TestT5WrapperCheckpoint(TorchtextTestCase):
     @parameterized.expand(["jit", "not_jit"])
     def test_t5_wrapper_checkpoint(self, name) -> None:
         test_text = ["translate English to French: I want to eat pizza for dinner."]
diff --git a/test/test_utils.py b/test/test_utils.py
index 3262cc0dc3..c28299dc82 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -5,18 +5,13 @@
 import unittest
 from urllib.parse import urljoin
 
-from test.common.assets import get_asset_path
+from test.common.assets import conditional_remove, get_asset_path
 from torchtext import _TEXT_BUCKET
 from torchtext import utils
 
 from .common.torchtext_test_case import TorchtextTestCase
 
 
-def conditional_remove(f):
-    if os.path.isfile(f):
-        os.remove(f)
-
-
 class TestUtils(TorchtextTestCase):
     def test_download_extract_tar(self) -> None:
         # create root directory for downloading data
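
Usage note (not part of the patch above): the new `--use-tmp-hub-dir` pytest option, together with the class-scoped `temp_hub_dir` fixture added in `test/integration_tests/conftest.py`, redirects Torch Hub downloads into a temporary directory that is removed once the test class finishes; without the flag the fixture is a no-op. A minimal sketch of how a test class could opt into the fixture follows; the class name, test body, and invocation are illustrative assumptions, not part of this change:

    import pytest
    import torch

    # Hypothetical test class placed under test/integration_tests/, where the
    # conftest.py from this patch is collected by pytest.
    @pytest.mark.usefixtures("temp_hub_dir")
    class TestHubDownloadExample:
        def test_hub_dir(self):
            # When pytest is invoked with --use-tmp-hub-dir, torch.hub.get_dir()
            # points at the fixture's temporary directory; otherwise it is the
            # default hub cache location.
            assert torch.hub.get_dir()

Invocation sketch (path assumed from the patch): pytest test/integration_tests --use-tmp-hub-dir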