diff --git a/test/common/parameterized_utils.py b/test/common/parameterized_utils.py
index 5094579501..c97c8bdb4f 100644
--- a/test/common/parameterized_utils.py
+++ b/test/common/parameterized_utils.py
@@ -1,6 +1,7 @@
 import json
+from itertools import product
 
-from parameterized import param
+from parameterized import param, parameterized
 
 from .assets import get_asset_path
 
@@ -8,3 +9,45 @@
 def load_params(*paths):
     with open(get_asset_path(*paths), "r") as file:
         return [param(json.loads(line)) for line in file]
+
+
+def _name_func(func, _, params):
+    strs = []
+    for arg in params.args:
+        if isinstance(arg, tuple):
+            strs.append("_".join(str(a) for a in arg))
+        else:
+            strs.append(str(arg))
+    # sanitize the test name
+    name = "_".join(strs).replace(".", "_")
+    return f"{func.__name__}_{name}"
+
+
+def nested_params(*params_set):
+    """Generate the cartesian product of the given list of parameters.
+    Args:
+        params_set (list of parameters): Parameters. When using ``parameterized.param`` class,
+            all the parameters have to be specified with the class, only using kwargs.
+    """
+    flatten = [p for params in params_set for p in params]
+
+    # Parameters to be nested are given as list of plain objects
+    if all(not isinstance(p, param) for p in flatten):
+        args = list(product(*params_set))
+        return parameterized.expand(args, name_func=_name_func)
+
+    # Parameters to be nested are given as list of `parameterized.param`
+    if not all(isinstance(p, param) for p in flatten):
+        raise TypeError(
+            "When using ``parameterized.param``, "
+            "all the parameters have to be of the ``param`` type."
+        )
+    if any(p.args for p in flatten):
+        raise ValueError(
+            "When using ``parameterized.param``, "
+            "all the parameters have to be provided as keyword argument."
+        )
+    args = [param()]
+    for params in params_set:
+        args = [param(**x.kwargs, **y.kwargs) for x in args for y in params]
+    return parameterized.expand(args)
diff --git a/test/integration_tests/test_models.py b/test/integration_tests/test_models.py
index f4c9eba687..67e10b0ee8 100644
--- a/test/integration_tests/test_models.py
+++ b/test/integration_tests/test_models.py
@@ -1,5 +1,4 @@
 import torch
-from parameterized import parameterized
 from torchtext.models import (
     XLMR_BASE_ENCODER,
     XLMR_LARGE_ENCODER,
@@ -8,6 +7,7 @@
 )
 
 from ..common.assets import get_asset_path
+from ..common.parameterized_utils import nested_params
 from ..common.torchtext_test_case import TorchtextTestCase
 
 TEST_MODELS_PARAMETERIZED_ARGS = [
@@ -27,30 +27,40 @@
 
 
 class TestModels(TorchtextTestCase):
-    @parameterized.expand(TEST_MODELS_PARAMETERIZED_ARGS)
-    def test_model(self, expected_asset_name, test_text, model_bundler):
+    @nested_params(
+        [
+            ("xlmr.base.output.pt", "XLMR base Model Comparison", XLMR_BASE_ENCODER),
+            ("xlmr.large.output.pt", "XLMR base Model Comparison", XLMR_LARGE_ENCODER),
+            (
+                "roberta.base.output.pt",
+                "Roberta base Model Comparison",
+                ROBERTA_BASE_ENCODER,
+            ),
+            (
+                "roberta.large.output.pt",
+                "Roberta base Model Comparison",
+                ROBERTA_LARGE_ENCODER,
+            ),
+        ],
+        [True, False],
+    )
+    def test_model(self, model_args, is_jit):
+        """Verify pre-trained XLM-R and Roberta models in torchtext produce
+        the same output as the reference implementation within fairseq
+        """
+        expected_asset_name, test_text, model_bundler = model_args
+
         expected_asset_path = get_asset_path(expected_asset_name)
 
         transform = model_bundler.transform()
         model = model_bundler.get_model()
         model = model.eval()
 
+        if is_jit:
+            transform = torch.jit.script(transform)
+            model = torch.jit.script(model)
+
         model_input = torch.tensor(transform([test_text]))
         actual = model(model_input)
         expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected)
-
-    @parameterized.expand(TEST_MODELS_PARAMETERIZED_ARGS)
-    def test_model_jit(self, expected_asset_name, test_text, model_bundler):
-        expected_asset_path = get_asset_path(expected_asset_name)
-
-        transform = model_bundler.transform()
-        transform_jit = torch.jit.script(transform)
-        model = model_bundler.get_model()
-        model = model.eval()
-        model_jit = torch.jit.script(model)
-
-        model_input = torch.tensor(transform_jit([test_text]))
-        actual = model_jit(model_input)
-        expected = torch.load(expected_asset_path)
-        torch.testing.assert_close(actual, expected)
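For reference, a minimal usage sketch of the new `nested_params` helper (not part of the diff). It takes one list per parameter axis and expands their cartesian product into separate test cases, either from plain objects or from keyword-only `parameterized.param` entries. The import path, test class, and method names below are illustrative only; in the repo the helper lives in `test/common/parameterized_utils.py` and is imported relatively.

```python
# Usage sketch of nested_params (illustrative, not part of the diff).
import unittest

from parameterized import param

# Hypothetical flat import; in the repo: from ..common.parameterized_utils import nested_params
from parameterized_utils import nested_params


class NestedParamsExample(unittest.TestCase):
    # Plain objects: cartesian product of the two axes -> 4 generated tests,
    # named via _name_func, e.g. test_activation_relu_True.
    @nested_params(["relu", "tanh"], [True, False])
    def test_activation(self, name, inplace):
        self.assertIn(name, {"relu", "tanh"})
        self.assertIsInstance(inplace, bool)

    # ``param`` entries: kwargs from each axis are merged and passed as
    # keyword arguments -> 2 generated tests (float32/cpu, float64/cpu).
    @nested_params(
        [param(dtype="float32"), param(dtype="float64")],
        [param(device="cpu")],
    )
    def test_dtype_device(self, dtype, device):
        self.assertEqual(device, "cpu")


if __name__ == "__main__":
    unittest.main()
```

This mirrors how `test_model` in the diff folds the old `test_model`/`test_model_jit` pair into a single parameterization over `(asset, text, bundler) × is_jit`.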