Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added test/asset/t5.base.encoder.output.pt
Binary file not shown.
Binary file added test/asset/t5.base.output.pt
Binary file not shown.
Empty file.
35 changes: 35 additions & 0 deletions test/prototype/integration_tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
from test.common.assets import get_asset_path
from test.common.torchtext_test_case import TorchtextTestCase
from torchtext.prototype.models import (
T5_BASE_ENCODER,
T5_BASE,
)


class TestT5(TorchtextTestCase):
def _t5_model(self, t5_model, expected_asset_name, model_input):
"""Verify that pre-trained T5 models in torchtext produce
the same output as the HuggingFace reference implementation.
"""
expected_asset_path = get_asset_path(expected_asset_name)
model = t5_model.get_model()
model = model.eval()

if model.encoder_only:
actual = model(model_input)["encoder_output"]
else:
actual = model(model_input)["decoder_output"]

expected = torch.load(expected_asset_path)
torch.testing.assert_close(actual, expected)

def test_t5_base_encoder_model(self):
expected_asset_name = "t5.base.encoder.output.pt"
model_input = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just something to note, when we implement the transform for the model, we probably want to update the test to pass in an input string to the _t5_model method. The helper function will be responsible for applying the transform on the input string to get the tensor that can be passed into the model (code pointer). The T5Bundle class will also need to be updated to store the model transform as a member variable (code pointer).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, will keep this in mind for the next task!

self._t5_model(t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, model_input=model_input)

def test_t5_base_model(self):
expected_asset_name = "t5.base.output.pt"
model_input = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]])
self._t5_model(t5_model=T5_BASE, expected_asset_name=expected_asset_name, model_input=model_input)
Empty file.
117 changes: 117 additions & 0 deletions test/prototype/models/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from unittest.mock import patch

from test.common.torchtext_test_case import TorchtextTestCase


class TestModels(TorchtextTestCase):
def test_t5_bundler_build_model(self):
from torchtext.prototype.models import T5Conf, T5Model, T5Bundle

# case: user provide encoder checkpoint state dict
dummy_encoder_conf = T5Conf(
encoder_only=True,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
)
dummy_t5_encoder = T5Model(dummy_encoder_conf)
t5_encoder_model = T5Bundle.build_model(config=dummy_encoder_conf, checkpoint=dummy_t5_encoder.state_dict())
self.assertEqual(t5_encoder_model.state_dict(), dummy_t5_encoder.state_dict())

# case: user provide encoder-decoder checkpoint state dict
dummy_t5_conf = T5Conf(
encoder_only=False,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
)
dummy_t5 = T5Model(dummy_t5_conf)
t5_model = T5Bundle.build_model(config=dummy_t5_conf, checkpoint=dummy_t5.state_dict())
self.assertEqual(t5_model.state_dict(), dummy_t5.state_dict())

@patch("logging.Logger.warning")
def test_t5_bundler_get_model(self, mock):
from torchtext.prototype.models import T5Conf, T5Bundle

# encoder-only
dummy_encoder_conf = T5Conf(
encoder_only=True,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
)
encoder_bundle = T5Bundle(dummy_encoder_conf)
encoder_bundle.get_model(load_weights=False, freeze_model=True)
mock.assert_called_with(
"The model is not loaded with pre-trained weights. Setting freeze_model to True will hinder model from learning appropriate weights."
)

# encoder-decoder
dummy_t5_conf = T5Conf(
encoder_only=False,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
)
t5_bundle = T5Bundle(dummy_t5_conf)
t5_bundle.get_model(load_weights=False, freeze_model=True)
mock.assert_called_with(
"The model is not loaded with pre-trained weights. Setting freeze_model to True will hinder model from learning appropriate weights."
)

def test_t5_bundler_raise_checkpoint(self):
from torchtext.prototype.models import T5Conf, T5Bundle

# encoder-only
with self.assertRaises(TypeError):
dummy_encoder_conf = T5Conf(
encoder_only=True,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
)
T5Bundle.build_model(
config=dummy_encoder_conf,
freeze_model=True,
checkpoint=1,
)

# encoder-decoder
with self.assertRaises(TypeError):
dummy_t5_conf = T5Conf(
encoder_only=False,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
)
T5Bundle.build_model(
config=dummy_t5_conf,
freeze_model=True,
checkpoint=1,
)

def test_t5_bundler_conf_property(self):
from torchtext.prototype.models import T5Conf, T5Bundle

dummy_t5_conf = T5Conf(
encoder_only=False,
vocab_size=10,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
)
t5_bundle = T5Bundle(dummy_t5_conf)
self.assertTrue(isinstance(t5_bundle.config, T5Conf))