This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
1 change: 0 additions & 1 deletion .github/workflows/integration-test.yml
@@ -55,7 +55,6 @@ jobs:
python3 -m pip --quiet install sentencepiece
python3 -m pip --quiet install tqdm
python3 -m pip --quiet install expecttest
-python3 -m pip --quiet install transformers
# Run Tests
python3 -m torch.utils.collect_env
cd test
2 changes: 1 addition & 1 deletion README.rst
@@ -122,7 +122,7 @@ The library currently consists of the following pre-trained models:
* `DistilRoBERTa <https://github.com/huggingface/transformers/blob/main/examples/research_projects/distillation/README.md>`_
* XLM-RoBERTa: `Base and Large Architecture <https://github.com/pytorch/fairseq/tree/main/examples/xlmr#pre-trained-models>`_
* T5: `Small, Base, Large, 3B, and 11B Architecture <https://github.com/google-research/text-to-text-transfer-transformer>`_
-* Flan-T5: `Small, Base, Large, XL, and XXL Architecture <https://github.com/google-research/t5x>`_
+* Flan-T5: `Base, Large, XL, and XXL Architecture <https://github.com/google-research/t5x>`_
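
A minimal usage sketch for one of these bundles, assuming the standard torchtext bundle API of ``get_model()`` and ``transform()`` (decoder inputs may need to be passed explicitly depending on the task)::

    import torch
    from torchtext.models import FLAN_T5_BASE_GENERATION

    transform = FLAN_T5_BASE_GENERATION.transform()
    model = FLAN_T5_BASE_GENERATION.get_model()
    model.eval()

    tokens = transform(["translate English to German: The house is wonderful."])
    with torch.no_grad():
        output = model(tokens)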

Tokenizers
==========
184 changes: 75 additions & 109 deletions test/integration_tests/test_t5_models.py
@@ -1,10 +1,15 @@
+import os
import tempfile

import pytest  # noqa: F401
import torch
from parameterized import parameterized_class
-from torchtext.models import T5Bundle
+from torchtext import _TEXT_BUCKET
+from torchtext._download_hooks import _TEST_DOWNLOAD_MANAGER
from torchtext.models import (
+    FLAN_T5_BASE,
+    FLAN_T5_BASE_ENCODER,
+    FLAN_T5_BASE_GENERATION,
    T5_BASE,
    T5_BASE_ENCODER,
    T5_BASE_GENERATION,
@@ -14,11 +19,11 @@
    T5_SMALL,
    T5_SMALL_ENCODER,
    T5_SMALL_GENERATION,
+    T5Bundle,
)
from torchtext_unittest.common.assets import get_asset_path
from torchtext_unittest.common.parameterized_utils import nested_params
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
-from transformers import T5EncoderModel, T5ForConditionalGeneration, T5Model

BUNDLERS = {
    "base_model": T5_BASE,
@@ -30,6 +35,9 @@
    "large_model": T5_LARGE,
    "large_encoder": T5_LARGE_ENCODER,
    "large_generation": T5_LARGE_GENERATION,
+    "flan_base_encoder": FLAN_T5_BASE_ENCODER,
+    "flan_base_model": FLAN_T5_BASE,
+    "flan_base_generation": FLAN_T5_BASE_GENERATION,
}


@@ -45,6 +53,9 @@
        ("large_model",),
        ("large_encoder",),
        ("large_generation",),
+        ("flan_base_encoder",),
+        ("flan_base_model",),
+        ("flan_base_generation",),
    ],
)
class TestT5Model(TorchtextTestCase):
@@ -74,126 +85,81 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):

    def _t5_get_encoder(self, model, model_input, encoder_output):
        encoder = model.get_encoder()
-        # Need to set the tgt_key_padding_mask to ensure the same results
+        # Need to set the key_padding_mask to ensure the same results
        encoder_padding_mask = model_input.eq(model.padding_idx)
        output_from_get_encoder = encoder(model_input, src_key_padding_mask=encoder_padding_mask)["encoder_output"]
        assert torch.all(output_from_get_encoder.eq(encoder_output))

-    @nested_params(["jit", "not_jit"])
+    @nested_params(["not_jit", "jit"])
    def test_t5_model(self, name) -> None:
-        configuration, type = self.model_name.split("_")
+        names = self.model_name.split("_")

+        num_names = len(names)
+
+        if num_names == 3:
+            # Handled slightly differently for Flan-T5 model naming
+            configuration = names[1]
+            type = names[2]
+            expected_asset_name = f"t5.flan.{configuration}.{type}.output.pt"
+            t5_model = BUNDLERS["flan_" + configuration + "_" + type]
+        elif num_names == 2:
+            configuration = names[0]
+            type = names[1]
+            expected_asset_name = f"t5.{configuration}.{type}.output.pt"
+            t5_model = BUNDLERS[configuration + "_" + type]
+        else:
+            raise RuntimeError(f"Unknown model name: {self.model_name}")

-        expected_asset_name = f"t5.{configuration}.{type}.output.pt"
        test_text = ["Hello world", "Attention rocks!"]
        is_jit = name == "jit"
-        t5_model = BUNDLERS[configuration + "_" + type]
        self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text)
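
The naming scheme the test parses, illustrated with one example of each form (asset names taken from the f-strings above):

    # "base_model"        -> "t5.base.model.output.pt",        BUNDLERS["base_model"]
    # "flan_base_encoder" -> "t5.flan.base.encoder.output.pt", BUNDLERS["flan_base_encoder"]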


+@parameterized_class(
+    ("model",),
+    [
+        ("hf_t5_small_encoder",),
+        ("hf_t5_small",),
+        ("hf_t5_small_generation",),
+        ("hf_flan_base_encoder",),
+        ("hf_flan_base",),
+        ("hf_flan_base_generation",),
+    ],
+)
class TestLoadFromHFCheckpoints(TorchtextTestCase):
    def setUp(self) -> None:
        super().setUp()
        self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]])
-        self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])
+        self.encoder_padding_mask = torch.tensor(
+            [[False, False, False, False, False, False], [False, False, False, True, True, True]]
+        )
        self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]])
-        self.decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])

-    def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None:
-        # check that encoder layers match
-        for i in range(config.num_encoder_layers + 1):
-            if i < config.num_encoder_layers:
-                hf_output_sa = hf_output.attentions[i] if encoder_only else hf_output.encoder_attentions[i]
-                # self-attention scores
-                assert torch.equal(
-                    our_output["encoder_sa_scores"][i], hf_output_sa
-                ), f"Mismatched self-attention scores for encoder layer {i}"
-            hf_output_hs = hf_output.hidden_states[i] if encoder_only else hf_output.encoder_hidden_states[i]
-            # encoder hidden states
-            assert torch.equal(
-                our_output["encoder_hidden_states"][i], hf_output_hs
-            ), f"Mismatched hidden states for encoder layer {i}"
-
-        if not encoder_only:
-            # check that decoder layers match
-            for i in range(config.num_decoder_layers + 1):
-                if i < config.num_encoder_layers:
-                    # self-attention scores
-                    assert torch.equal(
-                        our_output["decoder_sa_scores"][i], hf_output.decoder_attentions[i]
-                    ), f"Mismatched self-attention scores for decoder layer {i}"
-                    # cross-attention scores
-                    assert torch.equal(
-                        our_output["decoder_ca_scores"][i], hf_output.cross_attentions[i]
-                    ), f"Mismatched cross-attention scores for decoder layer {i}"
-                    # decoder hidden states
-                    assert torch.equal(
-                        our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i]
-                    ), f"Mismatched hidden states for decoder layer {i}"
-
-    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None:
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            model_path = f"{tmp_dir}/hf_t5_small_enc"
-
-            t5_small_enc = T5EncoderModel.from_pretrained("t5-small")
-            t5_small_enc.save_pretrained(model_path)
-
-            our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path, encoder_only=True)
-
-            hf_output = t5_small_enc(
-                input_ids=self.encoder_input_ids,
-                attention_mask=self.encoder_padding_mask,
-                output_hidden_states=True,
-                output_attentions=True,
-            )
-
-            our_output = our_encoder(self.encoder_input_ids)
-
-            self.check_outputs_of_models(our_output, hf_output, our_encoder.config, True)
-
-    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            model_path = f"{tmp_dir}/hf_t5_small"
-
-            t5_small = T5Model.from_pretrained("t5-small")
-            t5_small.save_pretrained(model_path)
-
-            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
-
-            hf_output = t5_small(
-                input_ids=self.encoder_input_ids,
-                decoder_input_ids=self.decoder_input_ids,
-                attention_mask=self.encoder_padding_mask,
-                decoder_attention_mask=self.decoder_padding_mask,
-                output_hidden_states=True,
-                output_attentions=True,
-            )
-
-            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)
-
-            self.check_outputs_of_models(our_output, hf_output, our_t5.config, False)
-
-    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None:
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            model_path = f"{tmp_dir}/hf_t5_small_gen"
-
-            t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small")
-            t5_small_gen.save_pretrained(model_path)
-
-            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)
-
-            hf_output = t5_small_gen(
-                input_ids=self.encoder_input_ids,
-                decoder_input_ids=self.decoder_input_ids,
-                attention_mask=self.encoder_padding_mask,
-                decoder_attention_mask=self.decoder_padding_mask,
-                output_hidden_states=True,
-                output_attentions=True,
-            )
-
-            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)
-
-            self.check_outputs_of_models(our_output, hf_output, our_t5.config, False)
-
-    def test_flan_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
-        # TODO(joecummings): Download FLAN-T5 chkpts and test here
-        pass
+        self.decoder_padding_mask = torch.tensor(
+            [[False, False, False, True, True, True], [False, False, False, True, True, True]]
+        )
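        # Note: these masks use torch's key_padding_mask convention (True =
        # padded position) rather than HF's attention_mask convention
        # (1 = attend), matching what the torchtext model's forward expects.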

+    def test_t5_bundler_load_hf_ckpt_pretrained(self) -> None:
+        with tempfile.TemporaryDirectory() as tmp:
+            local_path = f"{tmp}/{self.model}"
+            remote_bucket = f"{_TEXT_BUCKET}test_models"
+
+            os.mkdir(local_path)
+
+            for f in {"config.json", "pytorch_model.bin"}:
+                destination = f"{local_path}/{f}"
+                remote_path = f"{remote_bucket}/{self.model}/{f}"
+                _TEST_DOWNLOAD_MANAGER.get_local_path(url=remote_path, destination=destination)
+
+            names = self.model.split("_")
+            is_encoder_only = names[-1] == "encoder"
+
+            model = T5Bundle.build_model_from_huggingface_ckpt(local_path, encoder_only=is_encoder_only)
+            if is_encoder_only:
+                model(self.encoder_input_ids, encoder_padding_mask=self.encoder_padding_mask)
+            else:
+                model(
+                    self.encoder_input_ids,
+                    self.decoder_input_ids,
+                    encoder_padding_mask=self.encoder_padding_mask,
+                    decoder_padding_mask=self.decoder_padding_mask,
+                )
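
The same loader can still be exercised locally against an upstream checkpoint exported with HF transformers, as a sketch mirroring the tests removed above (requires the transformers package, which this PR drops from CI):

    import tempfile

    import torch
    from torchtext.models import T5Bundle
    from transformers import T5EncoderModel

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Export a checkpoint (config.json + pytorch_model.bin) in HF's PyTorch format...
        T5EncoderModel.from_pretrained("t5-small").save_pretrained(tmp_dir)
        # ...then rebuild it as a torchtext T5 encoder and run a forward pass.
        encoder = T5Bundle.build_model_from_huggingface_ckpt(tmp_dir, encoder_only=True)
        output = encoder(torch.tensor([[1, 2, 3, 4, 5, 6]]))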
3 binary files not shown.
1 change: 1 addition & 0 deletions torchtext/_download_hooks.py
@@ -59,3 +59,4 @@ def get_local_path(self, url, destination):


_DATASET_DOWNLOAD_MANAGER = DownloadManager()
+_TEST_DOWNLOAD_MANAGER = DownloadManager()
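
The new test-only manager mirrors the existing dataset manager; usage as in the test above (the URL layout under _TEXT_BUCKET is the one that test assumes):

    from torchtext import _TEXT_BUCKET
    from torchtext._download_hooks import _TEST_DOWNLOAD_MANAGER

    url = f"{_TEXT_BUCKET}test_models/hf_t5_small/config.json"
    _TEST_DOWNLOAD_MANAGER.get_local_path(url=url, destination="./config.json")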
6 changes: 0 additions & 6 deletions torchtext/models/t5/__init__.py
@@ -1,7 +1,4 @@
from .bundler import (
-    FLAN_T5_SMALL_ENCODER,
-    FLAN_T5_SMALL,
-    FLAN_T5_SMALL_GENERATION,
    FLAN_T5_BASE_ENCODER,
    FLAN_T5_BASE,
    FLAN_T5_BASE_GENERATION,
@@ -53,9 +50,6 @@
"T5_11B_ENCODER",
"T5_11B",
"T5_11B_GENERATION",
"FLAN_T5_SMALL_ENCODER",
"FLAN_T5_SMALL",
"FLAN_T5_SMALL_GENERATION",
"FLAN_T5_BASE_ENCODER",
"FLAN_T5_BASE",
"FLAN_T5_BASE_GENERATION",
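
With the small variants removed, the remaining Flan-T5 bundles import as before:

    from torchtext.models import FLAN_T5_BASE, FLAN_T5_BASE_ENCODER, FLAN_T5_BASE_GENERATION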
61 changes: 6 additions & 55 deletions torchtext/models/t5/bundler.py
@@ -155,6 +155,7 @@ def build_model_from_huggingface_ckpt(
"""Build T5Model model from a HuggingFace checkpoint.

Note: Only works with Huggingface models saved in the PyTorch format. Will not work with TensorFlow or JAX.
This also requires a fully saved model, sharded checkpoints are not supported.

Args:
ckpt_path (str, Path): Path to the HF checkpoint file. Assumes that the file is local.
@@ -238,12 +239,12 @@ def build_model_from_huggingface_ckpt(

        for i in range(config.num_decoder_layers):
            if config.is_gated_act:
-                t5_model_state_dict[f"encoder.layers.{i}.linear1_0.weight"] = hf_weights[
-                    f"decoder.block.{i}.layer.1.DenseReluDense.wi_0.weight"
+                t5_model_state_dict[f"decoder.layers.{i}.linear1_0.weight"] = hf_weights[
+                    f"decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight"
                ]

-                t5_model_state_dict[f"encoder.layers.{i}.linear1_1.weight"] = hf_weights[
-                    f"decoder.block.{i}.layer.1.DenseReluDense.wi_1.weight"
+                t5_model_state_dict[f"decoder.layers.{i}.linear1_1.weight"] = hf_weights[
+                    f"decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight"
                ]
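                # Why "layer.2": HF T5 decoder blocks order their sub-layers as
                # 0 = self-attention, 1 = cross-attention, 2 = feed-forward;
                # encoder blocks have no cross-attention, so their FFN sits at "layer.1".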
            else:
                t5_model_state_dict[f"decoder.layers.{i}.linear1.weight"] = hf_weights[
@@ -650,56 +651,6 @@ def t5_transform() -> T5Transform:

T5_11B_GENERATION.__doc__ = GENERATION_DOC.format("11B", "11B")


-FLAN_T5_SMALL_ENCODER = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.encoder.pt"),
-    _config=T5Conf(
-        encoder_only=True,
-        embedding_dim=512,
-        num_attention_heads=6,
-        num_encoder_layers=8,
-        num_decoder_layers=8,
-        ffn_dimension=1024,
-        feed_forward_proj="gated-gelu",
-    ),
-    transform=t5_transform,
-)
-
-FLAN_T5_SMALL_ENCODER.__doc__ = FLAN_ENCODER_DOC.format("SMALL", "SMALL")
-
-FLAN_T5_SMALL = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.pt"),
-    _config=T5Conf(
-        encoder_only=False,
-        embedding_dim=512,
-        num_attention_heads=6,
-        num_encoder_layers=8,
-        num_decoder_layers=8,
-        ffn_dimension=1024,
-        feed_forward_proj="gated-gelu",
-    ),
-    transform=t5_transform,
-)
-
-FLAN_T5_SMALL.__doc__ = FLAN_DOC.format("SMALL", "SMALL")
-
-FLAN_T5_SMALL_GENERATION = T5Bundle(
-    _path=urljoin(_TEXT_BUCKET, "t5.flan.small.generation.pt"),
-    _config=T5Conf(
-        encoder_only=False,
-        linear_head=True,
-        embedding_dim=512,
-        num_attention_heads=6,
-        num_encoder_layers=8,
-        num_decoder_layers=8,
-        ffn_dimension=1024,
-        feed_forward_proj="gated-gelu",
-    ),
-    transform=t5_transform,
-)
-
-FLAN_T5_SMALL_GENERATION.__doc__ = FLAN_GENERATION_DOC.format("SMALL", "SMALL")

FLAN_T5_BASE_ENCODER = T5Bundle(
_path=urljoin(_TEXT_BUCKET, "t5.flan.base.encoder.pt"),
_config=T5Conf(encoder_only=True, ffn_dimension=2048, feed_forward_proj="gated-gelu"),
@@ -762,7 +713,7 @@ def t5_transform() -> T5Transform:


FLAN_T5_LARGE_GENERATION = T5Bundle(
_path=urljoin(_TEXT_BUCKET, "t5.flan.large.encoder.pt"),
_path=urljoin(_TEXT_BUCKET, "t5.flan.large.generation.pt"),
_config=T5Conf(
encoder_only=False,
linear_head=True,