diff --git a/test/integration_tests/test_models.py b/test/integration_tests/test_models.py
index ee811beff1..f4c9eba687 100644
--- a/test/integration_tests/test_models.py
+++ b/test/integration_tests/test_models.py
@@ -1,67 +1,56 @@
 import torch
-import torchtext
+from parameterized import parameterized
+from torchtext.models import (
+    XLMR_BASE_ENCODER,
+    XLMR_LARGE_ENCODER,
+    ROBERTA_BASE_ENCODER,
+    ROBERTA_LARGE_ENCODER,
+)
 
 from ..common.assets import get_asset_path
 from ..common.torchtext_test_case import TorchtextTestCase
 
+TEST_MODELS_PARAMETERIZED_ARGS = [
+    ("xlmr.base.output.pt", "XLMR base Model Comparison", XLMR_BASE_ENCODER),
+    ("xlmr.large.output.pt", "XLMR base Model Comparison", XLMR_LARGE_ENCODER),
+    (
+        "roberta.base.output.pt",
+        "Roberta base Model Comparison",
+        ROBERTA_BASE_ENCODER,
+    ),
+    (
+        "roberta.large.output.pt",
+        "Roberta base Model Comparison",
+        ROBERTA_LARGE_ENCODER,
+    ),
+]
+
 
-class TestModels(TorchtextTestCase):
-    def test_roberta_base(self):
-        asset_path = get_asset_path("roberta.base.output.pt")
-        test_text = "Roberta base Model Comparison"
-
-        roberta_base = torchtext.models.ROBERTA_BASE_ENCODER
-        transform = roberta_base.transform()
-        model = roberta_base.get_model()
-        model = model.eval()
-
-        model_input = torch.tensor(transform([test_text]))
-        actual = model(model_input)
-        expected = torch.load(asset_path)
-        torch.testing.assert_close(actual, expected)
-
-    def test_roberta_base_jit(self):
-        asset_path = get_asset_path("roberta.base.output.pt")
-        test_text = "Roberta base Model Comparison"
-
-        roberta_base = torchtext.models.ROBERTA_BASE_ENCODER
-        transform = roberta_base.transform()
-        transform_jit = torch.jit.script(transform)
-        model = roberta_base.get_model()
-        model = model.eval()
-        model_jit = torch.jit.script(model)
-
-        model_input = torch.tensor(transform_jit([test_text]))
-        actual = model_jit(model_input)
-        expected = torch.load(asset_path)
-        torch.testing.assert_close(actual, expected)
-
-    def test_roberta_large(self):
-        asset_path = get_asset_path("roberta.large.output.pt")
-        test_text = "Roberta base Model Comparison"
+class TestModels(TorchtextTestCase):
+    @parameterized.expand(TEST_MODELS_PARAMETERIZED_ARGS)
+    def test_model(self, expected_asset_name, test_text, model_bundler):
+        expected_asset_path = get_asset_path(expected_asset_name)
 
-        roberta_large = torchtext.models.ROBERTA_LARGE_ENCODER
-        transform = roberta_large.transform()
-        model = roberta_large.get_model()
+        transform = model_bundler.transform()
+        model = model_bundler.get_model()
         model = model.eval()
 
         model_input = torch.tensor(transform([test_text]))
         actual = model(model_input)
-        expected = torch.load(asset_path)
+        expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected)
 
-    def test_roberta_large_jit(self):
-        asset_path = get_asset_path("roberta.large.output.pt")
-        test_text = "Roberta base Model Comparison"
+    @parameterized.expand(TEST_MODELS_PARAMETERIZED_ARGS)
+    def test_model_jit(self, expected_asset_name, test_text, model_bundler):
+        expected_asset_path = get_asset_path(expected_asset_name)
 
-        roberta_large = torchtext.models.ROBERTA_LARGE_ENCODER
-        transform = roberta_large.transform()
+        transform = model_bundler.transform()
         transform_jit = torch.jit.script(transform)
-        model = roberta_large.get_model()
+        model = model_bundler.get_model()
         model = model.eval()
         model_jit = torch.jit.script(model)
 
         model_input = torch.tensor(transform_jit([test_text]))
         actual = model_jit(model_input)
-        expected = torch.load(asset_path)
+        expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected)
diff --git a/test/models/test_models.py b/test/models/test_models.py
index 58942acb62..cf14984917 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -3,7 +3,6 @@
 from torch.nn import functional as torch_F
 import copy
 from ..common.torchtext_test_case import TorchtextTestCase
-from ..common.assets import get_asset_path
 
 
 class TestModules(TorchtextTestCase):
@@ -37,69 +36,6 @@ def test_self_attn_mask(self):
 
 
 class TestModels(TorchtextTestCase):
-    def test_xlmr_base_output(self):
-        asset_name = "xlmr.base.output.pt"
-        asset_path = get_asset_path(asset_name)
-        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
-        model = xlmr_base.get_model()
-        model = model.eval()
-        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
-        actual = model(model_input)
-        expected = torch.load(asset_path)
-        torch.testing.assert_close(actual, expected)
-
-    def test_xlmr_base_jit_output(self):
-        asset_name = "xlmr.base.output.pt"
-        asset_path = get_asset_path(asset_name)
-        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
-        model = xlmr_base.get_model()
-        model = model.eval()
-        model_jit = torch.jit.script(model)
-        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
-        actual = model_jit(model_input)
-        expected = torch.load(asset_path)
-        torch.testing.assert_close(actual, expected)
-
-    def test_xlmr_large_output(self):
-        asset_name = "xlmr.large.output.pt"
-        asset_path = get_asset_path(asset_name)
-        xlmr_base = torchtext.models.XLMR_LARGE_ENCODER
-        model = xlmr_base.get_model()
-        model = model.eval()
-        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
-        actual = model(model_input)
-        expected = torch.load(asset_path)
-        torch.testing.assert_close(actual, expected)
-
-    def test_xlmr_large_jit_output(self):
-        asset_name = "xlmr.large.output.pt"
-        asset_path = get_asset_path(asset_name)
-        xlmr_base = torchtext.models.XLMR_LARGE_ENCODER
-        model = xlmr_base.get_model()
-        model = model.eval()
-        model_jit = torch.jit.script(model)
-        model_input = torch.tensor([[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]])
-        actual = model_jit(model_input)
-        expected = torch.load(asset_path)
-        torch.testing.assert_close(actual, expected)
-
-    def test_xlmr_transform(self):
-        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
-        transform = xlmr_base.transform()
-        test_text = "XLMR base Model Comparison"
-        actual = transform([test_text])
-        expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
-        torch.testing.assert_close(actual, expected)
-
-    def test_xlmr_transform_jit(self):
-        xlmr_base = torchtext.models.XLMR_BASE_ENCODER
-        transform = xlmr_base.transform()
-        transform_jit = torch.jit.script(transform)
-        test_text = "XLMR base Model Comparison"
-        actual = transform_jit([test_text])
-        expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
-        torch.testing.assert_close(actual, expected)
-
     def test_roberta_bundler_build_model(self):
         from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
         dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
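
For context on how the new tests are collected, below is a minimal standalone sketch of the parameterized.expand pattern the patch adopts. It is not part of the patch; the CASES list, class, and method names are invented for illustration.

import unittest

from parameterized import parameterized

# Hypothetical cases, mirroring the (asset_name, test_text, bundler) tuples in the patch.
CASES = [
    ("upper", "abc", "ABC"),
    ("lower", "ABC", "abc"),
]


class TestStringMethods(unittest.TestCase):
    @parameterized.expand(CASES)
    def test_str_method(self, method_name, text, expected):
        # expand() generates one test per tuple (e.g. test_str_method_0_upper),
        # passing the tuple elements as positional arguments.
        self.assertEqual(getattr(text, method_name)(), expected)


if __name__ == "__main__":
    unittest.main()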