
Commit a933cbe

Beginning of generation utils and necessary refactors of T5 Model (#2011)
* Separate encoding/decoding logic for the T5 model in preparation for generation
* Add generation utils with greedy search and tests
1 parent f653dac commit a933cbe
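
For readers new to the utilities introduced here: greedy search decodes autoregressively, appending the single highest-probability token at each step until EOS or a length cap. The sketch below is illustrative only, not the committed GenerationUtil implementation; the model interface it assumes (a forward call returning next-token logits) is a simplification.

```python
import torch


def greedy_decode(model, encoder_tokens, max_len: int, pad_idx: int, eos_idx: int):
    """Illustrative greedy search: repeatedly append the argmax token.

    `model` is assumed to return next-token logits of shape (B, T, V) for
    (encoder_tokens, decoder_tokens); this is a sketch, not torchtext's
    GenerationUtil.
    """
    batch_size = encoder_tokens.size(0)
    # T5-style decoding starts from the pad token.
    decoder_tokens = torch.full((batch_size, 1), pad_idx, dtype=torch.long)
    for _ in range(max_len):
        logits = model(encoder_tokens=encoder_tokens, decoder_tokens=decoder_tokens)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        decoder_tokens = torch.cat([decoder_tokens, next_token], dim=-1)
        if (next_token == eos_idx).all():
            break
    return decoder_tokens
```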

File tree

10 files changed: +614 additions, −134 deletions


notebooks/hf_vs_tt_t5.ipynb

Lines changed: 83 additions & 0 deletions

### Ensuring the TorchText T5 implementation matches other OSS implementations

> In order to run this notebook, you will need to install the HuggingFace library with the following command: `pip install transformers`

```python
from transformers import T5Model
from torchtext.prototype.models import T5_BASE

import torch
```

```python
input_sentence = ["translate to Spanish: My name is Joe"]
output_sentence = ["Me llamo Joe"]

transform = T5_BASE.transform()
tt_t5_model = T5_BASE.get_model()

hf_t5_model = T5Model.from_pretrained("t5-base")
```

```python
tokenized_sentence = transform(input_sentence)
tokenized_output = transform(output_sentence)

tt_output = tt_t5_model(encoder_tokens=tokenized_sentence, decoder_tokens=tokenized_output)
hf_output = hf_t5_model(input_ids=tokenized_sentence, decoder_input_ids=tokenized_output, return_dict=True)

assert torch.all(tt_output["encoder_output"].eq(hf_output["encoder_last_hidden_state"]))
assert torch.all(tt_output["decoder_output"].eq(hf_output["last_hidden_state"]))
```
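
A hedged aside, not part of this commit: the exact-equality asserts above only hold when both implementations load identical weights and run identical kernels. A tolerance-based check is more robust across library versions; the tolerances below are the ones this commit's integration tests use.

```python
# Version-tolerant variant of the parity asserts above (an aside, not in the commit).
torch.testing.assert_close(
    tt_output["encoder_output"],
    hf_output["encoder_last_hidden_state"],
    atol=1e-04,  # same tolerances as the integration tests in this commit
    rtol=2.5e-06,
)
```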
(A second new notebook demonstrates GenerationUtil with HuggingFace's T5, BART, and GPT-2.)

Lines changed: 151 additions & 0 deletions

In order to run this notebook, you will need to install the HuggingFace library with the following command: `pip install transformers`

```python
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
)
from torchtext.prototype.generate import GenerationUtil
```

```python
t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
```

```python
# Testing HuggingFace's T5
test_sequence = ["summarize: studies have shown that owning a dog is good for you"]
generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
test_sequence_tk = t5_tokenizer(test_sequence, return_tensors="pt").input_ids
tokens = generative_hf_t5.generate(test_sequence_tk, max_len=20, pad_idx=t5.config.pad_token_id)
print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))
```

Output: `['owning a dog is good for you, according to studies. a dog is']`

```python
# Testing HuggingFace's BART
test_sequence = [
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
]
generative_hf_bart = GenerationUtil(bart, is_encoder_decoder=True, is_huggingface_model=True)
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
test_sequence_tk = bart_tokenizer(test_sequence, return_tensors="pt").input_ids
tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id)
print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))
```

Output: `['PG. PG&E said it scheduled the blackouts in response to forecasts for high winds.']`

```python
# Testing HuggingFace's GPT-2
test_sequence = ["I enjoy walking with my cute dog"]
generative_hf_gpt2 = GenerationUtil(gpt2, is_encoder_decoder=False, is_huggingface_model=True)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
test_sequence_tk = gpt2_tokenizer(test_sequence, return_tensors="pt").input_ids
tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)
print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))
```

Output: `["I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to"]`
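
The notebook above drives HuggingFace models through GenerationUtil. The same wrapper works with torchtext's own bundles; below is a minimal sketch that mirrors the unit tests added later in this commit (the T5_BASE_GENERATION bundle, the generate arguments, and the expected German output all come from those tests).

```python
from torchtext.prototype.generate import GenerationUtil
from torchtext.prototype.models import T5_BASE_GENERATION

transform = T5_BASE_GENERATION.transform()
model = T5_BASE_GENERATION.get_model()
model.eval()

# torchtext models take the native path, so is_huggingface_model is left False.
generator = GenerationUtil(model)
tokens = generator.generate(transform(["translate English to German: That is good."]), num_beams=1, max_len=30)
print(transform.decode(tokens.tolist()))  # per the unit tests below: ["Das ist gut."]
```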

test/integration_tests/prototype/test_models.py

Lines changed: 16 additions & 5 deletions

```diff
@@ -66,13 +66,22 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
         model_input = transform(test_text)
         if model.encoder_only:
-            actual = model(model_input)["encoder_output"]
+            actual = model(encoder_tokens=model_input)["encoder_output"]
+            if not is_jit:
+                self._t5_get_encoder(model, model_input, actual)
         else:
-            actual = model(model_input)["decoder_output"]
+            actual = model(encoder_tokens=model_input)["decoder_output"]

         expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)

+    def _t5_get_encoder(self, model, model_input, encoder_output):
+        encoder = model.get_encoder()
+        # Need to set the tgt_key_padding_mask to ensure the same results
+        encoder_padding_mask = model_input.eq(model.padding_idx)
+        output_from_get_encoder = encoder(tgt=model_input, tgt_key_padding_mask=encoder_padding_mask)["encoder_output"]
+        assert torch.all(output_from_get_encoder.eq(encoder_output))
+
     @nested_params(["jit", "not_jit"])
     def test_t5_model(self, name) -> None:
         configuration, type = self.model_name.split("_")
@@ -93,7 +102,8 @@ def test_t5_model(self, name) -> None:
     ],
 )
 class TestT5Wrapper(TorchtextTestCase):
-    @parameterized.expand(["jit", "not_jit"])
+    # No longer Torchscriptable
+    @parameterized.expand(["no_jit"])
     def test_t5_wrapper(self, name) -> None:
         configuration = self.configuration
         test_text = ["translate English to French: I want to eat pizza for dinner."]
@@ -113,7 +123,8 @@ def test_t5_wrapper(self, name) -> None:


 class TestT5WrapperCheckpoint(TorchtextTestCase):
-    @parameterized.expand(["jit", "not_jit"])
+    # No longer Torchscriptable
+    @parameterized.expand(["no_jit"])
     def test_t5_wrapper_checkpoint(self, name) -> None:
         test_text = ["translate English to French: I want to eat pizza for dinner."]
         expected_text = ["Je veux manger de la pizza pour le dîner."]
@@ -127,7 +138,7 @@ def test_t5_wrapper_checkpoint(self, name) -> None:
             padding_idx=0,
         )
         model = T5Wrapper(
-            checkpoint="https://download.pytorch.org/models/text/t5.base.generation.pt",
+            checkpoint="https://download.pytorch.org/models/text/t5.base.generation.v2.pt",
             t5_config=config,
             transform=transform,
             freeze_model=True,
```
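
The new _t5_get_encoder helper verifies that the standalone encoder returned by get_encoder() matches the full forward pass; the key detail is that the padding mask must be supplied explicitly. A minimal sketch of that path, assuming the T5_BASE bundle; the tgt and tgt_key_padding_mask argument names come from the diff above.

```python
from torchtext.prototype.models import T5_BASE

transform = T5_BASE.transform()
model = T5_BASE.get_model()
model.eval()

tokens = transform(["translate English to French: I want to eat pizza for dinner."])
encoder = model.get_encoder()
# Without this mask the standalone encoder would attend over pad positions
# and its output would diverge from the full model's forward pass.
padding_mask = tokens.eq(model.padding_idx)
encoder_output = encoder(tgt=tokens, tgt_key_padding_mask=padding_mask)["encoder_output"]
```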
(A new unit-test file covers the generation utils.)

Lines changed: 54 additions & 0 deletions

```python
from unittest.mock import patch

import torch
from torchtext.prototype.generate import DEFAULT_MAX_SEQ_LEN, GenerationUtil
from torchtext.prototype.models import T5_BASE_GENERATION
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase


class TestGenerationUtil(TorchtextTestCase):
    def setUp(self) -> None:
        super().setUp()
        t5_base = T5_BASE_GENERATION
        self.transform = t5_base.transform()
        self.model = t5_base.get_model()
        self.model.eval()
        # Examples taken from T5 Paper and Huggingface
        self.inputs = self.transform(
            [
                "summarize: studies have shown that owning a dog is good for you",
                "translate English to German: That is good.",
                "cola sentence: The course is jumping well.",
                "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
                "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
            ]
        )
        torch.manual_seed(0)

    def test_greedy_generate_with_t5(self) -> None:
        generation_model = GenerationUtil(self.model)

        tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30)
        generated_text = self.transform.decode(tokens.tolist())

        expected_generated_text = [
            "a dog is good for you, according to studies . owning a dog is good for you, according to studies .",
            "Das ist gut.",
            "acceptable",
            "4.0",
            "mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage",
        ]

        self.assertEqual(generated_text, expected_generated_text)

    def test_generate_errors_with_incorrect_beams(self) -> None:
        generation_model = GenerationUtil(self.model, is_encoder_decoder=True)

        with self.assertRaises(ValueError):
            generation_model.generate(self.inputs, num_beams=0)

    @patch("logging.Logger.warning")
    def test_warns_when_no_max_len_provided(self, mock) -> None:
        generation_model = GenerationUtil(self.model)
        generation_model.generate(self.inputs)
        mock.assert_called_with(f"`max_len` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
```
