Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
run: |
python -m pip install --quiet --upgrade pip
python -m pip install --quiet --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
python -m pip install --quiet pytest requests cmake ninja sentencepiece parameterized tqdm expecttest
python -m pip install --quiet pytest requests cmake ninja sentencepiece parameterized tqdm expecttest transformers
python setup.py install
- name: Run integration test
run: |
Expand Down
110 changes: 109 additions & 1 deletion test/integration_tests/prototype/test_models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import tempfile

import pytest # noqa: F401
import torch
from parameterized import parameterized, parameterized_class
Expand All @@ -14,11 +16,12 @@
T5Conf,
T5Transform,
)
from torchtext.prototype.models.t5.bundler import T5Bundle
from torchtext.prototype.models.t5.wrapper import T5Wrapper
from torchtext_unittest.common.assets import get_asset_path
from torchtext_unittest.common.parameterized_utils import nested_params
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase

from transformers import T5Model, T5EncoderModel, T5ForConditionalGeneration

BUNDLERS = {
"base_model": T5_BASE,
Expand Down Expand Up @@ -135,3 +138,108 @@ def test_t5_wrapper_checkpoint(self, name) -> None:

output_text = model(test_text, beam_size, max_seq_len)
self.assertEqual(output_text, expected_text)


class TestLoadFromHFCheckpoints(TorchtextTestCase):
    """Check that T5 models built from HuggingFace checkpoints reproduce the HuggingFace outputs."""

    def setUp(self) -> None:
        """Create small fixed encoder/decoder inputs shared by every test."""
        super().setUp()
        # Each mask below is 1 exactly where the matching ids tensor is non-zero.
        self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]])
        self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])
        self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]])
        self.decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])

    def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None:
        """Assert that our model's per-layer outputs are exactly equal to the HuggingFace ones.

        Args:
            our_output: Mapping produced by our model; the keys read here are
                ``encoder_sa_scores``, ``encoder_hidden_states`` and, when not
                ``encoder_only``, ``decoder_sa_scores``, ``decoder_ca_scores`` and
                ``decoder_hidden_states``.
            hf_output: HuggingFace model output obtained with
                ``output_hidden_states=True`` and ``output_attentions=True``.
            config: Config exposing ``num_encoder_layers`` and ``num_decoder_layers``.
            encoder_only (bool): Whether ``hf_output`` comes from an encoder-only model,
                which names its fields ``attentions``/``hidden_states`` instead of
                ``encoder_attentions``/``encoder_hidden_states``.

        Raises:
            AssertionError: If any compared pair of tensors differs.
        """
        # Hidden states are compared at one more index than attention scores:
        # indices 0..num_layers inclusive versus 0..num_layers-1.
        for i in range(config.num_encoder_layers + 1):
            if i < config.num_encoder_layers:
                hf_output_sa = hf_output.attentions[i] if encoder_only else hf_output.encoder_attentions[i]
                # self-attention scores
                assert torch.equal(
                    our_output["encoder_sa_scores"][i], hf_output_sa
                ), f"Mismatched self-attention scores for encoder layer {i}"
            hf_output_hs = hf_output.hidden_states[i] if encoder_only else hf_output.encoder_hidden_states[i]
            # encoder hidden states
            assert torch.equal(
                our_output["encoder_hidden_states"][i], hf_output_hs
            ), f"Mismatched hidden states for encoder layer {i}"

        if encoder_only:
            return

        # check that decoder layers match
        for i in range(config.num_decoder_layers + 1):
            # Guard with the *decoder* depth: using the encoder depth here would skip
            # checks or index out of range when the two depths differ.
            if i < config.num_decoder_layers:
                # self-attention scores
                assert torch.equal(
                    our_output["decoder_sa_scores"][i], hf_output.decoder_attentions[i]
                ), f"Mismatched self-attention scores for decoder layer {i}"
                # cross-attention scores
                assert torch.equal(
                    our_output["decoder_ca_scores"][i], hf_output.cross_attentions[i]
                ), f"Mismatched cross-attention scores for decoder layer {i}"
            # decoder hidden states
            assert torch.equal(
                our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i]
            ), f"Mismatched hidden states for decoder layer {i}"

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None:
        """Round-trip an encoder-only ``t5-small`` through disk and compare outputs."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small_enc"

            t5_small_enc = T5EncoderModel.from_pretrained("t5-small")
            t5_small_enc.save_pretrained(model_path)

            our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small_enc(
                input_ids=self.encoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_encoder(self.encoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_encoder.config, encoder_only=True)

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
        """Round-trip an encoder-decoder ``t5-small`` through disk and compare outputs."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small"

            t5_small = T5Model.from_pretrained("t5-small")
            t5_small.save_pretrained(model_path)

            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small(
                input_ids=self.encoder_input_ids,
                decoder_input_ids=self.decoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                decoder_attention_mask=self.decoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False)

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None:
        """Round-trip ``t5-small`` with its generation head through disk and compare outputs."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small_gen"

            t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small")
            t5_small_gen.save_pretrained(model_path)

            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small_gen(
                input_ids=self.encoder_input_ids,
                decoder_input_ids=self.decoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                decoder_attention_mask=self.decoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False)
134 changes: 134 additions & 0 deletions torchtext/prototype/models/t5/bundler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union
from urllib.parse import urljoin
Expand Down Expand Up @@ -133,6 +135,138 @@ def build_model(

return model

@staticmethod
def build_model_from_huggingface_ckpt(
    ckpt_path: Union[str, os.PathLike],
    *,
    freeze_model: bool = False,
    strict: bool = True,
) -> T5Model:
    """Build T5Model model from a HuggingFace checkpoint.

    Note: Only works with Huggingface models saved in the PyTorch format. Will not work with TensorFlow or JAX.

    Args:
        ckpt_path (str, Path): Path to a local directory holding the HF checkpoint. The directory must
            contain ``config.json`` and ``pytorch_model.bin``.
        freeze_model (bool): Freeze the model upon loading. (Default: `False`)
        strict (bool): Load model in strict mode. (Default: `True`)

    Returns:
        T5Model loaded with the weights of the HuggingFace checkpoint provided

    Raises:
        FileNotFoundError: If ``config.json`` or ``pytorch_model.bin`` is missing from ``ckpt_path``.
        KeyError: If the config or the checkpoint lacks an entry this conversion reads.
    """
    config_path = os.path.join(ckpt_path, "config.json")
    model_path = os.path.join(ckpt_path, "pytorch_model.bin")

    with open(config_path, "r", encoding="utf-8") as handle:
        config_json = json.load(handle)

    # NOTE(security): ``torch.load`` unpickles the file, so only point this at checkpoints you trust.
    # Load onto CPU so checkpoints saved from an accelerator still open on CPU-only hosts;
    # ``load_state_dict`` below copies the values into the model's own parameters.
    hf_weights = torch.load(model_path, map_location="cpu")

    num_encoder_layers = config_json["num_layers"]
    # HuggingFace treats an unset ``num_decoder_layers`` as equal to ``num_layers``.
    num_decoder_layers = config_json.get("num_decoder_layers")
    if num_decoder_layers is None:
        num_decoder_layers = num_encoder_layers

    # TODO(joecummings): find better way to determine `encoder_only` and `linear_head`
    config = T5Conf(
        encoder_only="decoder.final_layer_norm.weight" not in hf_weights,
        linear_head="lm_head.weight" in hf_weights,
        embedding_dim=config_json["d_model"],
        num_attention_heads=config_json["num_heads"],
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        ffn_dimension=config_json["d_ff"],
    )

    t5_model = T5Model(config, freeze_model)

    # (our parameter suffix, HuggingFace parameter suffix) for every encoder layer.
    encoder_layer_names = (
        ("linear1.weight", "layer.1.DenseReluDense.wi.weight"),
        ("linear2.weight", "layer.1.DenseReluDense.wo.weight"),
        ("norm1.weight", "layer.0.layer_norm.weight"),
        ("norm2.weight", "layer.1.layer_norm.weight"),
        ("self_attn.out_proj.weight", "layer.0.SelfAttention.o.weight"),
        ("self_attn.q_proj_weight", "layer.0.SelfAttention.q.weight"),
        ("self_attn.k_proj_weight", "layer.0.SelfAttention.k.weight"),
        ("self_attn.v_proj_weight", "layer.0.SelfAttention.v.weight"),
    )
    # Same for every decoder layer. Our ``norm2`` takes HF sub-layer 2's norm and our
    # ``norm3`` takes HF sub-layer 1's norm, so the numbering intentionally does not line up.
    decoder_layer_names = (
        ("linear1.weight", "layer.2.DenseReluDense.wi.weight"),
        ("linear2.weight", "layer.2.DenseReluDense.wo.weight"),
        ("norm1.weight", "layer.0.layer_norm.weight"),
        ("norm2.weight", "layer.2.layer_norm.weight"),
        ("norm3.weight", "layer.1.layer_norm.weight"),
        ("self_attn.out_proj.weight", "layer.0.SelfAttention.o.weight"),
        ("self_attn.q_proj_weight", "layer.0.SelfAttention.q.weight"),
        ("self_attn.k_proj_weight", "layer.0.SelfAttention.k.weight"),
        ("self_attn.v_proj_weight", "layer.0.SelfAttention.v.weight"),
        ("cross_attn.out_proj.weight", "layer.1.EncDecAttention.o.weight"),
        ("cross_attn.q_proj_weight", "layer.1.EncDecAttention.q.weight"),
        ("cross_attn.k_proj_weight", "layer.1.EncDecAttention.k.weight"),
        ("cross_attn.v_proj_weight", "layer.1.EncDecAttention.v.weight"),
    )

    t5_model_state_dict = {
        "token_embeddings.weight": hf_weights["shared.weight"],
        "norm1.weight": hf_weights["encoder.final_layer_norm.weight"],
        # Only layer 0 of each stack has a relative attention bias entry copied over.
        "encoder.layers.0.self_attn.relative_attention_bias.weight": hf_weights[
            "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
        ],
    }

    def _convert_layers(stack, num_layers, name_pairs):
        # Copy every per-layer weight of one stack ("encoder" or "decoder") under our naming.
        for i in range(num_layers):
            for our_name, hf_name in name_pairs:
                t5_model_state_dict[f"{stack}.layers.{i}.{our_name}"] = hf_weights[f"{stack}.block.{i}.{hf_name}"]

    # Convert encoder layers
    _convert_layers("encoder", config.num_encoder_layers, encoder_layer_names)

    # Convert decoder layers if model is encoder-decoder
    if not config.encoder_only:
        t5_model_state_dict["norm2.weight"] = hf_weights["decoder.final_layer_norm.weight"]
        t5_model_state_dict["decoder.layers.0.self_attn.relative_attention_bias.weight"] = hf_weights[
            "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
        ]
        _convert_layers("decoder", config.num_decoder_layers, decoder_layer_names)

    # Convert language modeling head if there is one
    if config.linear_head:
        t5_model_state_dict["lm_head.weight"] = hf_weights["lm_head.weight"]

    # Load state dict into our model
    t5_model.load_state_dict(t5_model_state_dict, strict=strict)

    return t5_model

@property
def config(self) -> T5Conf:
    """Return the configuration object held in ``self._config``."""
    stored_config = self._config
    return stored_config
Expand Down