45 changes: 44 additions & 1 deletion test/common/parameterized_utils.py
@@ -1,10 +1,53 @@
import json
from itertools import product

from parameterized import param
from parameterized import param, parameterized

from .assets import get_asset_path


def load_params(*paths):
with open(get_asset_path(*paths), "r") as file:
return [param(json.loads(line)) for line in file]
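# A minimal usage sketch (the asset name below is hypothetical): given
# an asset file "model_params.jsonl" containing one JSON object per
# line, e.g.
#     {"lr": 0.1, "epochs": 2}
#     {"lr": 0.01, "epochs": 5}
# each line becomes a single ``param`` whose only positional argument
# is the decoded dict:
#     params = load_params("model_params.jsonl")
#     params[0].args  ->  ({"lr": 0.1, "epochs": 2},)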


def _name_func(func, _, params):
    # Build a readable, valid test-method name from the parameter values.
    strs = []
for arg in params.args:
if isinstance(arg, tuple):
strs.append("_".join(str(a) for a in arg))
else:
strs.append(str(arg))
# sanitize the test name
name = "_".join(strs).replace(".", "_")
return f"{func.__name__}_{name}"


def nested_params(*params_set):
"""Generate the cartesian product of the given list of parameters.
Args:
params_set (list of parameters): Parameters. When using ``parameterized.param`` class,
all the parameters have to be specified with the class, only using kwargs.
"""
flatten = [p for params in params_set for p in params]

# Parameters to be nested are given as list of plain objects
if all(not isinstance(p, param) for p in flatten):
args = list(product(*params_set))
return parameterized.expand(args, name_func=_name_func)

# Parameters to be nested are given as list of `parameterized.param`
if not all(isinstance(p, param) for p in flatten):
raise TypeError(
"When using ``parameterized.param``, "
"all the parameters have to be of the ``param`` type."
)
if any(p.args for p in flatten):
raise ValueError(
"When using ``parameterized.param``, "
"all the parameters have to be provided as keyword argument."
)
    # Merge the kwargs of each ``param`` pairwise, building up the cartesian
    # product one parameter set at a time.
    args = [param()]
    for params in params_set:
        args = [param(**x.kwargs, **y.kwargs) for x in args for y in params]
    return parameterized.expand(args)
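
For future reference, a minimal usage sketch of the two calling conventions (the test class and parameter values are illustrative only, not part of this PR):

from parameterized import param

from ..common.parameterized_utils import nested_params
from ..common.torchtext_test_case import TorchtextTestCase


class TestExample(TorchtextTestCase):
    # Plain objects: the cartesian product is expanded and the cases are
    # named via _name_func, e.g. test_plain_a_1, ..., test_plain_b_2.
    @nested_params(["a", "b"], [1, 2])
    def test_plain(self, letter, number):
        ...

    # ``param`` objects: kwargs are merged pairwise across the product and
    # the test body receives them as keyword arguments (default naming).
    @nested_params(
        [param(mode="jit"), param(mode="eager")],
        [param(device="cpu")],
    )
    def test_kwargs(self, mode=None, device=None):
        ...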
46 changes: 28 additions & 18 deletions test/integration_tests/test_models.py
@@ -1,5 +1,4 @@
import torch
from parameterized import parameterized
from torchtext.models import (
XLMR_BASE_ENCODER,
XLMR_LARGE_ENCODER,
@@ -8,6 +7,7 @@
)

from ..common.assets import get_asset_path
from ..common.parameterized_utils import nested_params
from ..common.torchtext_test_case import TorchtextTestCase

TEST_MODELS_PARAMETERIZED_ARGS = [
@@ -27,30 +27,40 @@


class TestModels(TorchtextTestCase):
@parameterized.expand(TEST_MODELS_PARAMETERIZED_ARGS)
def test_model(self, expected_asset_name, test_text, model_bundler):
@nested_params(
[
("xlmr.base.output.pt", "XLMR base Model Comparison", XLMR_BASE_ENCODER),
("xlmr.large.output.pt", "XLMR base Model Comparison", XLMR_LARGE_ENCODER),
(
"roberta.base.output.pt",
"Roberta base Model Comparison",
ROBERTA_BASE_ENCODER,
),
(
"roberta.large.output.pt",
"Roberta base Model Comparison",
ROBERTA_LARGE_ENCODER,
),
],
[True, False],
)
def test_model(self, model_args, is_jit):
Contributor:

nit: This test is rather self-explanatory, so I think it's okay, but I recommend writing a short docstring for tests so that future maintainers can tell what the intention was and judge whether the implementation makes sense.

Writing a docstring also helps define what constitutes a correct test when there are no intuitive test criteria for a new feature.

Contributor Author (@Nayef211, Jan 10, 2022):

Thanks for the feedback. Just added a short docstring to describe what's being tested.

"""Verify pre-trained XLM-R and Roberta models in torchtext produce
the same output as the reference implementation within fairseq
"""
expected_asset_name, test_text, model_bundler = model_args

expected_asset_path = get_asset_path(expected_asset_name)

transform = model_bundler.transform()
model = model_bundler.get_model()
model = model.eval()

if is_jit:
transform = torch.jit.script(transform)
model = torch.jit.script(model)

model_input = torch.tensor(transform([test_text]))
actual = model(model_input)
expected = torch.load(expected_asset_path)
torch.testing.assert_close(actual, expected)

@parameterized.expand(TEST_MODELS_PARAMETERIZED_ARGS)
def test_model_jit(self, expected_asset_name, test_text, model_bundler):
expected_asset_path = get_asset_path(expected_asset_name)

transform = model_bundler.transform()
transform_jit = torch.jit.script(transform)
model = model_bundler.get_model()
model = model.eval()
model_jit = torch.jit.script(model)

model_input = torch.tensor(transform_jit([test_text]))
actual = model_jit(model_input)
expected = torch.load(expected_asset_path)
torch.testing.assert_close(actual, expected)