Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Binary file added test/asset/t5.large.encoder.output.pt
Binary file not shown.
Binary file added test/asset/t5.large.generation.output.pt
Binary file not shown.
Binary file added test/asset/t5.large.model.output.pt
Binary file not shown.
Binary file added test/asset/t5.small.encoder.output.pt
Binary file not shown.
Binary file added test/asset/t5.small.generation.output.pt
Binary file not shown.
Binary file added test/asset/t5.small.model.output.pt
Binary file not shown.
73 changes: 44 additions & 29 deletions test/prototype/integration_tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,37 @@
import torch
from parameterized import parameterized
from test.common.assets import get_asset_path
from test.common.parameterized_utils import nested_params
from test.common.torchtext_test_case import TorchtextTestCase
from torchtext.prototype.models import T5_BASE_ENCODER, T5_BASE, T5_BASE_GENERATION, T5Conf, T5Transform
from torchtext.prototype.models import (
T5_BASE_ENCODER,
T5_BASE,
T5_BASE_GENERATION,
T5_SMALL_ENCODER,
T5_SMALL,
T5_SMALL_GENERATION,
T5_LARGE_ENCODER,
T5_LARGE,
T5_LARGE_GENERATION,
T5Conf,
T5Transform,
)
from torchtext.prototype.models.t5.wrapper import T5Wrapper


# Maps "<size>_<type>" (e.g. "base_encoder") to the corresponding pre-trained
# T5 bundle, so the parameterized tests can look a bundle up by configuration.
BUNDLERS = dict(
    base_model=T5_BASE,
    base_encoder=T5_BASE_ENCODER,
    base_generation=T5_BASE_GENERATION,
    small_model=T5_SMALL,
    small_encoder=T5_SMALL_ENCODER,
    small_generation=T5_SMALL_GENERATION,
    large_model=T5_LARGE,
    large_encoder=T5_LARGE_ENCODER,
    large_generation=T5_LARGE_GENERATION,
)


class TestT5(TorchtextTestCase):
def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
"""Verify that pre-trained T5 models in torchtext produce
Expand All @@ -29,43 +55,32 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
expected = torch.load(expected_asset_path)
torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)

@parameterized.expand([("jit", True), ("not_jit", False)])
def test_t5_base_encoder_model(self, name, is_jit) -> None:
expected_asset_name = "t5.base.encoder.output.pt"
test_text = ["Hello world", "Attention rocks!"]
self._t5_model(
is_jit=is_jit, t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, test_text=test_text
)

@parameterized.expand([("jit", True), ("not_jit", False)])
def test_t5_base_model(self, name, is_jit) -> None:
expected_asset_name = "t5.base.output.pt"
test_text = ["Hello world", "Attention rocks!"]
self._t5_model(is_jit=is_jit, t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text)

@parameterized.expand([("jit", True), ("not_jit", False)])
def test_t5_base_generation_model(self, name, is_jit) -> None:
expected_asset_name = "t5.base.generation.output.pt"
@nested_params(["base", "small", "large"], ["encoder", "model", "generation"], ["jit", "not_jit"])
def test_t5_encoder_model(self, configuration, type, name) -> None:
expected_asset_name = f"t5.{configuration}.{type}.output.pt"
test_text = ["Hello world", "Attention rocks!"]
self._t5_model(
is_jit=is_jit, t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text
)
is_jit = name == "jit"
t5_model = BUNDLERS[configuration + "_" + type]
self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text)

@parameterized.expand([("jit", True), ("not_jit", False)])
def test_t5_wrapper(self, name, is_jit) -> None:
@nested_params(["base", "small", "large"], ["jit", "not_jit"])
def test_t5_wrapper(self, configuration, name) -> None:
test_text = ["translate English to French: I want to eat pizza for dinner."]
expected_text = ["Je veux manger de la pizza pour le dîner."]
if configuration == "small":
expected_text = ["Je veux manger la pizza pour le dîner."]
else:
expected_text = ["Je veux manger de la pizza pour le dîner."]
beam_size = 3
max_seq_len = 512
model = T5Wrapper(configuration="base")
if is_jit:
model = T5Wrapper(configuration=configuration)
if name == "jit":
model = torch.jit.script(model)

output_text = model(test_text, beam_size, max_seq_len)
self.assertEqual(output_text, expected_text)

@parameterized.expand([("jit", True), ("not_jit", False)])
def test_t5_wrapper_checkpoing(self, name, is_jit) -> None:
@parameterized.expand(["jit", "not_jit"])
def test_t5_wrapper_checkpoint(self, name) -> None:
test_text = ["translate English to French: I want to eat pizza for dinner."]
expected_text = ["Je veux manger de la pizza pour le dîner."]
beam_size = 3
Expand All @@ -84,7 +99,7 @@ def test_t5_wrapper_checkpoing(self, name, is_jit) -> None:
freeze_model=True,
strict=True,
)
if is_jit:
if name == "jit":
model = torch.jit.script(model)

output_text = model(test_text, beam_size, max_seq_len)
Expand Down
24 changes: 24 additions & 0 deletions torchtext/prototype/models/t5/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,18 @@
T5_BASE_ENCODER,
T5_BASE,
T5_BASE_GENERATION,
T5_SMALL_ENCODER,
T5_SMALL,
T5_SMALL_GENERATION,
T5_LARGE_ENCODER,
T5_LARGE,
T5_LARGE_GENERATION,
T5_3B_ENCODER,
T5_3B,
T5_3B_GENERATION,
T5_11B_ENCODER,
T5_11B,
T5_11B_GENERATION,
T5Bundle,
)
from .model import T5Conf, T5Model
Expand All @@ -14,5 +26,17 @@
"T5_BASE_ENCODER",
"T5_BASE",
"T5_BASE_GENERATION",
"T5_SMALL_ENCODER",
"T5_SMALL",
"T5_SMALL_GENERATION",
"T5_LARGE_ENCODER",
"T5_LARGE",
"T5_LARGE_GENERATION",
"T5_3B_ENCODER",
"T5_3B",
"T5_3B_GENERATION",
"T5_11B_ENCODER",
"T5_11B",
"T5_11B_GENERATION",
"T5Transform",
]
Loading