28 changes: 28 additions & 0 deletions test/integration_tests/conftest.py
@@ -0,0 +1,28 @@
import shutil

import pytest
import torch


def pytest_addoption(parser):
    parser.addoption(
        "--use-tmp-hub-dir",
        action="store_true",
        help=(
            "When provided, tests will use temporary directory as Torch Hub directory. "
            "Downloaded models will be deleted after each test."
        ),
    )


@pytest.fixture(scope="class")
def temp_hub_dir(tmp_path_factory, pytestconfig):
    if not pytestconfig.getoption("--use-tmp-hub-dir"):
Comment on lines +18 to +20

h-vetinari:

Follow-up in #1889 aside, I don't understand how this can work at all.

Both tmp_path_factory and pytestconfig appear nowhere else in the code base (as far as GH search and git grep go, at least), and based on the pytest docs this should probably be

Suggested change:
-@pytest.fixture(scope="class")
-def temp_hub_dir(tmp_path_factory, pytestconfig):
-    if not pytestconfig.getoption("--use-tmp-hub-dir"):
+@pytest.fixture(scope="class")
+def temp_hub_dir(tmp_path_factory, request):
+    if not request.config.getoption("--use-tmp-hub-dir"):

but that still leaves me mystified as to where tmp_path_factory is coming from.

h-vetinari:

> but that still leaves me mystified as to where tmp_path_factory is coming from.

Ah, just found that it's a fixture that comes with pytest by default.

h-vetinari:

OK, and so's pytestconfig. In any case, my suspicion now is that this only works when called from within the integration_tests folder, because (I guess) recursive conftest.pys are not evaluated when the CLI is called from a folder that's further out.
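
For readers unfamiliar with this pytest quirk, here is a sketch of the behavior being described, assuming this PR's layout: pytest pre-loads conftest.py files only for the paths named on the command line, so the custom option is most likely registered in the first invocation below but not in the second.

import pytest

# Recognized: the path argument causes test/integration_tests/conftest.py to be
# pre-loaded before option parsing, so --use-tmp-hub-dir exists.
pytest.main(["test/integration_tests", "--use-tmp-hub-dir"])

# Likely rejected with "unrecognized arguments": invoked from further out with
# no path into the test tree, that conftest.py is never pre-loaded.
pytest.main(["--use-tmp-hub-dir"])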

Nayef211 (Contributor, PR author), Jan 20, 2023:

@h-vetinari yup I think your understanding is correct. Does only being able to run this test from within the integration_tests folder pose an issue for adding torchtext to conda-forge?

h-vetinari:

> Does only being able to run this test from within the integration_tests folder pose an issue for adding torchtext to conda-forge?

No, I just had to realise the reason, after that it's simple. ;)

        yield
    else:
        tmp_dir = tmp_path_factory.mktemp("hub", numbered=True).resolve()
        org_dir = torch.hub.get_dir()
        torch.hub.set_dir(tmp_dir)
        yield
        torch.hub.set_dir(org_dir)
        shutil.rmtree(tmp_dir, ignore_errors=True)
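
As a usage note, a minimal sketch of how a test class might opt into this fixture; the class and test below are hypothetical, not part of this PR:

import pytest
import torch


@pytest.mark.usefixtures("temp_hub_dir")
class TestPretrainedDownload:
    def test_hub_dir_redirected(self):
        # With --use-tmp-hub-dir, torch.hub.get_dir() points at the fixture's
        # temporary directory, so any downloaded weights are cleaned up after
        # the class finishes; without the flag, the fixture is a no-op.
        assert torch.hub.get_dir()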
65 changes: 28 additions & 37 deletions test/integration_tests/test_models.py
@@ -1,5 +1,6 @@
import pytest # noqa: F401
import torch
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from torchtext.models import (
ROBERTA_BASE_ENCODER,
ROBERTA_LARGE_ENCODER,
@@ -10,7 +11,23 @@
from ..common.assets import get_asset_path
from ..common.torchtext_test_case import TorchtextTestCase

+BUNDLERS = {
+    "xlmr_base": XLMR_BASE_ENCODER,
+    "xlmr_large": XLMR_LARGE_ENCODER,
+    "roberta_base": ROBERTA_BASE_ENCODER,
+    "roberta_large": ROBERTA_LARGE_ENCODER,
+}


+@parameterized_class(
+    ("model_name",),
+    [
+        ("xlmr_base",),
+        ("xlmr_large",),
+        ("roberta_base",),
+        ("roberta_large",),
+    ],
+)
class TestRobertaEncoders(TorchtextTestCase):
    def _roberta_encoders(self, is_jit, encoder, expected_asset_name, test_text):
        """Verify pre-trained XLM-R and Roberta models in torchtext produce
@@ -31,46 +48,20 @@ def _roberta_encoders(self, is_jit, encoder, expected_asset_name, test_text):
        expected = torch.load(expected_asset_path)
        torch.testing.assert_close(actual, expected)

-    @parameterized.expand([("jit", True), ("not_jit", False)])
-    def test_xlmr_base_model(self, name, is_jit):
-        expected_asset_name = "xlmr.base.output.pt"
-        test_text = "XLMR base Model Comparison"
-        self._roberta_encoders(
-            is_jit=is_jit,
-            encoder=XLMR_BASE_ENCODER,
-            expected_asset_name=expected_asset_name,
-            test_text=test_text,
-        )
-
-    @parameterized.expand([("jit", True), ("not_jit", False)])
-    def test_xlmr_large_model(self, name, is_jit):
-        expected_asset_name = "xlmr.large.output.pt"
-        test_text = "XLMR base Model Comparison"
-        self._roberta_encoders(
-            is_jit=is_jit,
-            encoder=XLMR_LARGE_ENCODER,
-            expected_asset_name=expected_asset_name,
-            test_text=test_text,
-        )
+    @parameterized.expand(["jit", "not_jit"])
+    def test_models(self, name):
+        configuration, type = self.model_name.split("_")

-    @parameterized.expand([("jit", True), ("not_jit", False)])
-    def test_roberta_base_model(self, name, is_jit):
-        expected_asset_name = "roberta.base.output.pt"
-        test_text = "Roberta base Model Comparison"
-        self._roberta_encoders(
-            is_jit=is_jit,
-            encoder=ROBERTA_BASE_ENCODER,
-            expected_asset_name=expected_asset_name,
-            test_text=test_text,
-        )
+        expected_asset_name = f"{configuration}.{type}.output.pt"
+        is_jit = name == "jit"
+        if configuration == "xlmr":
+            test_text = "XLMR base Model Comparison"
+        else:
+            test_text = "Roberta base Model Comparison"

-    @parameterized.expand([("jit", True), ("not_jit", False)])
-    def test_robeta_large_model(self, name, is_jit):
-        expected_asset_name = "roberta.large.output.pt"
-        test_text = "Roberta base Model Comparison"
        self._roberta_encoders(
            is_jit=is_jit,
-            encoder=ROBERTA_LARGE_ENCODER,
+            encoder=BUNDLERS[configuration + "_" + type],
            expected_asset_name=expected_asset_name,
            test_text=test_text,
        )
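
The refactor above relies on parameterized_class, which generates one test class per parameter tuple and sets the named attribute on each instance. A self-contained sketch of the pattern, with illustrative names that are not from this diff:

import unittest

from parameterized import parameterized_class


@parameterized_class(
    ("model_name",),
    [
        ("xlmr_base",),
        ("roberta_large",),
    ],
)
class ExampleEncoderTest(unittest.TestCase):
    # Two classes are generated, each with self.model_name bound to one of the
    # tuples above, so every test method runs once per model.
    def test_name_splits_into_config_and_size(self):
        configuration, size = self.model_name.split("_")
        self.assertIn(configuration, {"xlmr", "roberta"})
        self.assertIn(size, {"base", "large"})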
43 changes: 37 additions & 6 deletions test/prototype/integration_tests/test_models.py
@@ -1,5 +1,6 @@
import pytest # noqa: F401
import torch
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
from test.common.assets import get_asset_path
from test.common.parameterized_utils import nested_params
from test.common.torchtext_test_case import TorchtextTestCase
@@ -32,7 +33,21 @@
}


-class TestT5(TorchtextTestCase):
+@parameterized_class(
+    ("model_name",),
+    [
+        ("base_model",),
+        ("base_encoder",),
+        ("base_generation",),
+        ("small_model",),
+        ("small_encoder",),
+        ("small_generation",),
+        ("large_model",),
+        ("large_encoder",),
+        ("large_generation",),
+    ],
+)
+class TestT5Model(TorchtextTestCase):
    def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
        """Verify that pre-trained T5 models in torchtext produce
        the same output as the HuggingFace reference implementation.
@@ -55,21 +70,35 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
        expected = torch.load(expected_asset_path)
        torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)

@nested_params(["base", "small", "large"], ["encoder", "model", "generation"], ["jit", "not_jit"])
def test_t5_encoder_model(self, configuration, type, name) -> None:
@nested_params(["jit", "not_jit"])
def test_t5_model(self, name) -> None:
configuration, type = self.model_name.split("_")

expected_asset_name = f"t5.{configuration}.{type}.output.pt"
test_text = ["Hello world", "Attention rocks!"]
is_jit = name == "jit"
t5_model = BUNDLERS[configuration + "_" + type]
self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text)

@nested_params(["base", "small", "large"], ["jit", "not_jit"])
def test_t5_wrapper(self, configuration, name) -> None:

@parameterized_class(
("configuration",),
[
("small",),
("base",),
("large",),
],
)
class TestT5Wrapper(TorchtextTestCase):
@parameterized.expand(["jit", "not_jit"])
def test_t5_wrapper(self, name) -> None:
configuration = self.configuration
test_text = ["translate English to French: I want to eat pizza for dinner."]
if configuration == "small":
expected_text = ["Je veux manger la pizza pour le dîner."]
else:
expected_text = ["Je veux manger de la pizza pour le dîner."]

        beam_size = 3
        max_seq_len = 512
        model = T5Wrapper(configuration=configuration)
@@ -79,6 +108,8 @@ def test_t5_wrapper(self, configuration, name) -> None:
        output_text = model(test_text, beam_size, max_seq_len)
        self.assertEqual(output_text, expected_text)


+class TestT5WrapperCheckpoint(TorchtextTestCase):
    @parameterized.expand(["jit", "not_jit"])
    def test_t5_wrapper_checkpoint(self, name) -> None:
        test_text = ["translate English to French: I want to eat pizza for dinner."]
7 changes: 1 addition & 6 deletions test/test_utils.py
@@ -5,18 +5,13 @@
import unittest
from urllib.parse import urljoin

-from test.common.assets import get_asset_path
+from test.common.assets import conditional_remove, get_asset_path
from torchtext import _TEXT_BUCKET
from torchtext import utils

from .common.torchtext_test_case import TorchtextTestCase


-def conditional_remove(f):
-    if os.path.isfile(f):
-        os.remove(f)


class TestUtils(TorchtextTestCase):
    def test_download_extract_tar(self) -> None:
        # create root directory for downloading data
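
The import change above implies that conditional_remove now lives in test/common/assets.py; assuming it was moved verbatim from test_utils.py, that module would now contain something like:

import os


def conditional_remove(f):
    # Delete the file only if it exists; do nothing otherwise.
    if os.path.isfile(f):
        os.remove(f)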