Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit de54db6

Browse files
authored
[Feature] Add ability to load HF checkpoints into T5 model (#1918)
* Add ability to load HF checkpoints into T5 model
* Add HuggingFace to integration tests
* Remove duplicate code
* Revert fix
* Add setup
* Remove ability to download from remote URL
* Remove line break from docstring
1 parent 3f9c349 commit de54db6

File tree

3 files changed

+244
-2
lines changed

3 files changed

+244
-2
lines changed

.github/workflows/integration-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
run: |
2525
python -m pip install --quiet --upgrade pip
2626
python -m pip install --quiet --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
27-
python -m pip install --quiet pytest requests cmake ninja sentencepiece parameterized tqdm expecttest
27+
python -m pip install --quiet pytest requests cmake ninja sentencepiece parameterized tqdm expecttest transformers
2828
python setup.py install
2929
- name: Run integration test
3030
run: |

test/integration_tests/prototype/test_models.py

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import tempfile
2+
13
import pytest # noqa: F401
24
import torch
35
from parameterized import parameterized, parameterized_class
@@ -14,11 +16,12 @@
1416
T5Conf,
1517
T5Transform,
1618
)
19+
from torchtext.prototype.models.t5.bundler import T5Bundle
1720
from torchtext.prototype.models.t5.wrapper import T5Wrapper
1821
from torchtext_unittest.common.assets import get_asset_path
1922
from torchtext_unittest.common.parameterized_utils import nested_params
2023
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
21-
24+
from transformers import T5Model, T5EncoderModel, T5ForConditionalGeneration
2225

2326
BUNDLERS = {
2427
"base_model": T5_BASE,
@@ -135,3 +138,108 @@ def test_t5_wrapper_checkpoint(self, name) -> None:
135138

136139
output_text = model(test_text, beam_size, max_seq_len)
137140
self.assertEqual(output_text, expected_text)
141+
142+
143+
class TestLoadFromHFCheckpoints(TorchtextTestCase):
    """Check that T5 models rebuilt from HuggingFace checkpoints match the HF originals.

    Each test saves a pretrained HuggingFace ``t5-small`` variant into a
    temporary directory, rebuilds it with
    ``T5Bundle.build_model_from_huggingface_ckpt``, runs both models on the
    same token ids and requires the per-layer attention scores and hidden
    states to be exactly equal (``torch.equal``).
    """

    def setUp(self) -> None:
        """Create the fixed encoder/decoder token ids and padding masks shared by all tests."""
        super().setUp()
        # Two sequences of length 6 each; the trailing zeros in the ids line up
        # with the zeros in the corresponding masks.
        self.encoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]])
        self.encoder_padding_mask = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])
        self.decoder_input_ids = torch.tensor([[7, 8, 9, 0, 0, 0], [10, 11, 12, 0, 0, 0]])
        self.decoder_padding_mask = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])

    def check_outputs_of_models(self, our_output, hf_output, config, encoder_only) -> None:
        """Assert that our model's per-layer outputs equal the HuggingFace model's.

        Args:
            our_output: Mapping returned by our model. Read keys are
                ``encoder_sa_scores`` and ``encoder_hidden_states`` and, when
                ``encoder_only`` is false, ``decoder_sa_scores``,
                ``decoder_ca_scores`` and ``decoder_hidden_states``.
            hf_output: HuggingFace model output. For encoder-only models the
                ``attentions`` / ``hidden_states`` attributes are read, otherwise
                the ``encoder_*``, ``decoder_*`` and ``cross_attentions`` ones.
            config: Object exposing ``num_encoder_layers`` and ``num_decoder_layers``.
            encoder_only (bool): Whether the compared models have no decoder.

        Raises:
            AssertionError: If any compared tensor pair is not exactly equal.
        """
        # check that encoder layers match
        # Hidden states are compared at indices 0..num_layers inclusive, while
        # attention scores only exist for indices 0..num_layers - 1, hence the
        # `+ 1` on the range and the guard below.
        for i in range(config.num_encoder_layers + 1):
            if i < config.num_encoder_layers:
                hf_output_sa = hf_output.attentions[i] if encoder_only else hf_output.encoder_attentions[i]
                # self-attention scores
                assert torch.equal(
                    our_output["encoder_sa_scores"][i], hf_output_sa
                ), f"Mismatched self-attention scores for encoder layer {i}"
            hf_output_hs = hf_output.hidden_states[i] if encoder_only else hf_output.encoder_hidden_states[i]
            # encoder hidden states
            assert torch.equal(
                our_output["encoder_hidden_states"][i], hf_output_hs
            ), f"Mismatched hidden states for encoder layer {i}"

        if not encoder_only:
            # check that decoder layers match
            for i in range(config.num_decoder_layers + 1):
                # The guard must use the *decoder* depth: with unequal encoder
                # and decoder depths the encoder count would either skip
                # decoder layers or index past the end of the attention tuples.
                if i < config.num_decoder_layers:
                    # self-attention scores
                    assert torch.equal(
                        our_output["decoder_sa_scores"][i], hf_output.decoder_attentions[i]
                    ), f"Mismatched self-attention scores for decoder layer {i}"
                    # cross-attention scores
                    assert torch.equal(
                        our_output["decoder_ca_scores"][i], hf_output.cross_attentions[i]
                    ), f"Mismatched cross-attention scores for decoder layer {i}"
                # decoder hidden states
                assert torch.equal(
                    our_output["decoder_hidden_states"][i], hf_output.decoder_hidden_states[i]
                ), f"Mismatched hidden states for decoder layer {i}"

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_only(self) -> None:
        """Round-trip ``T5EncoderModel`` ("t5-small") through a saved checkpoint and compare outputs."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small_enc"

            t5_small_enc = T5EncoderModel.from_pretrained("t5-small")
            t5_small_enc.save_pretrained(model_path)

            our_encoder = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small_enc(
                input_ids=self.encoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            # Only the HF model receives the explicit mask; our model is called
            # with the token ids alone (presumably it derives padding from the
            # ids -- confirm against T5Model.forward).
            our_output = our_encoder(self.encoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_encoder.config, encoder_only=True)

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder(self) -> None:
        """Round-trip HF ``T5Model`` ("t5-small") through a saved checkpoint and compare outputs."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small"

            t5_small = T5Model.from_pretrained("t5-small")
            t5_small.save_pretrained(model_path)

            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small(
                input_ids=self.encoder_input_ids,
                decoder_input_ids=self.decoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                decoder_attention_mask=self.decoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False)

    def test_t5_bundler_load_hf_ckpt_pretrained_encoder_decoder_with_gen(self) -> None:
        """Round-trip ``T5ForConditionalGeneration`` ("t5-small") through a saved checkpoint and compare outputs."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = f"{tmp_dir}/hf_t5_small_gen"

            t5_small_gen = T5ForConditionalGeneration.from_pretrained("t5-small")
            t5_small_gen.save_pretrained(model_path)

            our_t5 = T5Bundle.build_model_from_huggingface_ckpt(model_path)

            hf_output = t5_small_gen(
                input_ids=self.encoder_input_ids,
                decoder_input_ids=self.decoder_input_ids,
                attention_mask=self.encoder_padding_mask,
                decoder_attention_mask=self.decoder_padding_mask,
                output_hidden_states=True,
                output_attentions=True,
            )

            our_output = our_t5(self.encoder_input_ids, self.decoder_input_ids)

            self.check_outputs_of_models(our_output, hf_output, our_t5.config, encoder_only=False)

torchtext/prototype/models/t5/bundler.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import json
12
import logging
3+
import os
24
from dataclasses import dataclass
35
from typing import Any, Callable, Dict, Optional, Union
46
from urllib.parse import urljoin
@@ -133,6 +135,138 @@ def build_model(
133135

134136
return model
135137

138+
@staticmethod
def build_model_from_huggingface_ckpt(
    ckpt_path: Union[str, os.PathLike],
    *,
    freeze_model: bool = False,
    strict: bool = True,
) -> T5Model:
    """Build T5Model model from a HuggingFace checkpoint.

    Note: Only works with HuggingFace models saved in the PyTorch format. Will not work with TensorFlow or JAX.

    Args:
        ckpt_path (str, Path): Path to the local HF checkpoint directory. It must contain
            `config.json` and `pytorch_model.bin`.
        freeze_model (bool): Freeze the model upon loading. (Default: `False`)
        strict (bool): Load model in strict mode. (Default: `True`)

    Returns:
        T5Model loaded with the weights of the HuggingFace checkpoint provided

    Raises:
        FileNotFoundError: If `config.json` or `pytorch_model.bin` is missing from `ckpt_path`.
        KeyError: If a required entry is missing from the config or from the checkpoint weights.
    """
    # os.fspath / os.path.join accept any os.PathLike, not only objects whose
    # str() happens to be the path.
    ckpt_dir = os.fspath(ckpt_path)
    config_path = os.path.join(ckpt_dir, "config.json")
    model_path = os.path.join(ckpt_dir, "pytorch_model.bin")

    with open(config_path, "r", encoding="utf-8") as handle:
        config_json = json.load(handle)
    # Map every tensor to CPU so that checkpoints saved from an accelerator can
    # be loaded on machines without one; load_state_dict copies the values into
    # the model's own parameters afterwards.
    hf_weights = torch.load(model_path, map_location="cpu")

    # HuggingFace treats a missing/null `num_decoder_layers` as "same as num_layers".
    num_encoder_layers = config_json["num_layers"]
    num_decoder_layers = config_json.get("num_decoder_layers")
    if num_decoder_layers is None:
        num_decoder_layers = num_encoder_layers

    # TODO(joecummings): find better way to determine `encoder_only` and `linear_head`
    # NOTE(review): only the fields below are forwarded from the HF config; every
    # other T5Conf field keeps its default -- confirm that is intended.
    config = T5Conf(
        encoder_only="decoder.final_layer_norm.weight" not in hf_weights,
        linear_head="lm_head.weight" in hf_weights,
        embedding_dim=config_json["d_model"],
        num_attention_heads=config_json["num_heads"],
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        ffn_dimension=config_json["d_ff"],
    )

    t5_model = T5Model(config, freeze_model)

    # (our per-layer suffix, HF per-block suffix) pairs for one encoder layer.
    encoder_layer_map = (
        ("linear1.weight", "layer.1.DenseReluDense.wi.weight"),
        ("linear2.weight", "layer.1.DenseReluDense.wo.weight"),
        ("norm1.weight", "layer.0.layer_norm.weight"),
        ("norm2.weight", "layer.1.layer_norm.weight"),
        ("self_attn.out_proj.weight", "layer.0.SelfAttention.o.weight"),
        ("self_attn.q_proj_weight", "layer.0.SelfAttention.q.weight"),
        ("self_attn.k_proj_weight", "layer.0.SelfAttention.k.weight"),
        ("self_attn.v_proj_weight", "layer.0.SelfAttention.v.weight"),
    )
    # Same for one decoder layer. Note that norm2 maps to HF sub-layer 2 and
    # norm3 to HF sub-layer 1.
    decoder_layer_map = (
        ("linear1.weight", "layer.2.DenseReluDense.wi.weight"),
        ("linear2.weight", "layer.2.DenseReluDense.wo.weight"),
        ("norm1.weight", "layer.0.layer_norm.weight"),
        ("norm2.weight", "layer.2.layer_norm.weight"),
        ("norm3.weight", "layer.1.layer_norm.weight"),
        ("self_attn.out_proj.weight", "layer.0.SelfAttention.o.weight"),
        ("self_attn.q_proj_weight", "layer.0.SelfAttention.q.weight"),
        ("self_attn.k_proj_weight", "layer.0.SelfAttention.k.weight"),
        ("self_attn.v_proj_weight", "layer.0.SelfAttention.v.weight"),
        ("cross_attn.out_proj.weight", "layer.1.EncDecAttention.o.weight"),
        ("cross_attn.q_proj_weight", "layer.1.EncDecAttention.q.weight"),
        ("cross_attn.k_proj_weight", "layer.1.EncDecAttention.k.weight"),
        ("cross_attn.v_proj_weight", "layer.1.EncDecAttention.v.weight"),
    )

    # Weights that exist once per encoder rather than once per layer; the
    # relative attention bias is read from block 0 only.
    t5_model_state_dict = {
        "token_embeddings.weight": hf_weights["shared.weight"],
        "norm1.weight": hf_weights["encoder.final_layer_norm.weight"],
        "encoder.layers.0.self_attn.relative_attention_bias.weight": hf_weights[
            "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
        ],
    }
    # Convert encoder layers
    for i in range(config.num_encoder_layers):
        for our_suffix, hf_suffix in encoder_layer_map:
            t5_model_state_dict[f"encoder.layers.{i}.{our_suffix}"] = hf_weights[f"encoder.block.{i}.{hf_suffix}"]

    # Convert decoder layers if model is encoder-decoder
    if not config.encoder_only:
        t5_model_state_dict["norm2.weight"] = hf_weights["decoder.final_layer_norm.weight"]
        t5_model_state_dict["decoder.layers.0.self_attn.relative_attention_bias.weight"] = hf_weights[
            "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
        ]

        for i in range(config.num_decoder_layers):
            for our_suffix, hf_suffix in decoder_layer_map:
                t5_model_state_dict[f"decoder.layers.{i}.{our_suffix}"] = hf_weights[
                    f"decoder.block.{i}.{hf_suffix}"
                ]

    # Convert language modeling head if there is one
    if config.linear_head:
        t5_model_state_dict["lm_head.weight"] = hf_weights["lm_head.weight"]

    # Load state dict into our model
    t5_model.load_state_dict(t5_model_state_dict, strict)

    return t5_model
269+
136270
@property
def config(self) -> T5Conf:
    """Return the configuration object stored in ``self._config``."""
    return self._config

0 commit comments

Comments
 (0)