83 changes: 83 additions & 0 deletions notebooks/hf_vs_tt_t5.ipynb
@@ -0,0 +1,83 @@
{
Member Author: I know we often verify completeness with internal notebooks. For those that show parity with HuggingFace or external libraries, I thought we could put the notebooks in the actual repo. It seems like a better way to keep track of them than Bento notebooks with scattered ownership.

Contributor: Can you upload the notebook to a GitHub gist and provide a link in the PR so it's easier to review the contents?

Member Author: Sure, but as a quick fix you can click the "..." menu at the top right of this file and select "View file" to get a rendered notebook view.

"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Ensuring the TorchText T5 implementation matches other OSS implementations\n",
"\n",
"> In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"from transformers import T5Model\n",
"from torchtext.prototype.models import T5_BASE\n",
"\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"input_sentence = [\"translate to Spanish: My name is Joe\"]\n",
"output_sentence = [\"Me llamo Joe\"]\n",
"\n",
"transform = T5_BASE.transform()\n",
"tt_t5_model = T5_BASE.get_model()\n",
"\n",
"hf_t5_model = T5Model.from_pretrained(\"t5-base\")"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"tokenized_sentence = transform(input_sentence)\n",
"tokenized_output = transform(output_sentence)\n",
"\n",
"tt_output = tt_t5_model(encoder_tokens=tokenized_sentence, decoder_tokens=tokenized_output)\n",
"hf_output = hf_t5_model(input_ids=tokenized_sentence, decoder_input_ids=tokenized_output, return_dict=True)\n",
"\n",
"assert torch.all(tt_output[\"encoder_output\"].eq(hf_output[\"encoder_last_hidden_state\"]))\n",
"assert torch.all(tt_output[\"decoder_output\"].eq(hf_output[\"last_hidden_state\"]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('torchtext39')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
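The notebook above asserts exact element-wise equality between the TorchText and HuggingFace outputs, while the integration test in this PR (test_models.py) compares against reference outputs with torch.testing.assert_close and small tolerances. A minimal sketch of the same parity check with those tolerances, assuming the tt_output and hf_output dictionaries produced in the notebook above (the check_t5_parity helper is hypothetical):

import torch

# Hypothetical tolerance-based variant of the notebook's parity check, reusing
# the atol/rtol values from test/integration_tests/prototype/test_models.py.
def check_t5_parity(tt_output, hf_output, atol=1e-04, rtol=2.5e-06):
    torch.testing.assert_close(
        tt_output["encoder_output"], hf_output["encoder_last_hidden_state"], atol=atol, rtol=rtol
    )
    torch.testing.assert_close(
        tt_output["decoder_output"], hf_output["last_hidden_state"], atol=atol, rtol=rtol
    )

check_t5_parity(tt_output, hf_output)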
151 changes: 151 additions & 0 deletions notebooks/hf_with_torchtext_gen.ipynb
@@ -0,0 +1,151 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from transformers import T5ForConditionalGeneration, T5Tokenizer, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer\n",
"from torchtext.prototype.generate import GenerationUtil"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"t5 = T5ForConditionalGeneration.from_pretrained(\"t5-base\")\n",
"bart = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large-cnn\")\n",
"gpt2 = GPT2LMHeadModel.from_pretrained(\"gpt2\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
"For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
"- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
"- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
"- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"['owning a dog is good for you, according to studies. a dog is']\n"
]
}
],
"source": [
"# Testing Huggingface's T5\n",
"test_sequence = [\"summarize: studies have shown that owning a dog is good for you\"]\n",
"generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)\n",
"t5_tokenizer = T5Tokenizer.from_pretrained(\"t5-base\")\n",
"test_sequence_tk = t5_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=20, pad_idx=t5.config.pad_token_id)\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['PG. PG&E said it scheduled the blackouts in response to forecasts for high winds.']\n"
]
}
],
"source": [
"# Testing Huggingface's BART\n",
"test_sequence = [\"PG&E stated it scheduled the blackouts in response to forecasts for high winds \"\n",
" \"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were \"\n",
" \"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.\"]\n",
"generative_hf_bart = GenerationUtil(bart, is_encoder_decoder=True, is_huggingface_model=True)\n",
"bart_tokenizer = BartTokenizer.from_pretrained(\"facebook/bart-large-cnn\")\n",
"test_sequence_tk = bart_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id)\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\"I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to\"]\n"
]
}
],
"source": [
"# Testing Huggingface's GPT2\n",
"test_sequence = [\"I enjoy walking with my cute dog\"]\n",
"generative_hf_gpt2 = GenerationUtil(gpt2, is_encoder_decoder=False, is_huggingface_model=True)\n",
"gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
"test_sequence_tk = gpt2_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n",
"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('torchtext39')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
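The notebook above drives HuggingFace models through TorchText's GenerationUtil (is_huggingface_model=True). The same utility is exercised against TorchText's own T5 bundle in test_generate.py below; a minimal sketch of that native usage, assuming the prototype API shown in that test:

from torchtext.prototype.generate import GenerationUtil
from torchtext.prototype.models import T5_BASE_GENERATION

# Build the bundled TorchText T5 generation model and its text transform.
transform = T5_BASE_GENERATION.transform()
model = T5_BASE_GENERATION.get_model()
model.eval()

# Greedy decoding (num_beams=1), capped at max_len=30, mirroring the unit test below.
inputs = transform(["summarize: studies have shown that owning a dog is good for you"])
tokens = GenerationUtil(model).generate(inputs, num_beams=1, max_len=30)
print(transform.decode(tokens.tolist()))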
21 changes: 16 additions & 5 deletions test/integration_tests/prototype/test_models.py
@@ -66,13 +66,22 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
 
         model_input = transform(test_text)
         if model.encoder_only:
-            actual = model(model_input)["encoder_output"]
+            actual = model(encoder_tokens=model_input)["encoder_output"]
+            if not is_jit:
+                self._t5_get_encoder(model, model_input, actual)
         else:
-            actual = model(model_input)["decoder_output"]
+            actual = model(encoder_tokens=model_input)["decoder_output"]
 
         expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)
 
+    def _t5_get_encoder(self, model, model_input, encoder_output):
+        encoder = model.get_encoder()
+        # Need to set the tgt_key_padding_mask to ensure the same results
+        encoder_padding_mask = model_input.eq(model.padding_idx)
+        output_from_get_encoder = encoder(tgt=model_input, tgt_key_padding_mask=encoder_padding_mask)["encoder_output"]
+        assert torch.all(output_from_get_encoder.eq(encoder_output))
+
     @nested_params(["jit", "not_jit"])
     def test_t5_model(self, name) -> None:
         configuration, type = self.model_name.split("_")

@@ -93,7 +102,8 @@ def test_t5_model(self, name) -> None:
     ],
 )
 class TestT5Wrapper(TorchtextTestCase):
-    @parameterized.expand(["jit", "not_jit"])
+    # No longer Torchscriptable
+    @parameterized.expand(["no_jit"])
     def test_t5_wrapper(self, name) -> None:
         configuration = self.configuration
         test_text = ["translate English to French: I want to eat pizza for dinner."]

@@ -113,7 +123,8 @@ def test_t5_wrapper(self, name) -> None:
 
 
 class TestT5WrapperCheckpoint(TorchtextTestCase):
-    @parameterized.expand(["jit", "not_jit"])
+    # No longer Torchscriptable
+    @parameterized.expand(["no_jit"])
     def test_t5_wrapper_checkpoint(self, name) -> None:
         test_text = ["translate English to French: I want to eat pizza for dinner."]
         expected_text = ["Je veux manger de la pizza pour le dîner."]

@@ -127,7 +138,7 @@ def test_t5_wrapper_checkpoint(self, name) -> None:
             padding_idx=0,
         )
         model = T5Wrapper(
-            checkpoint="https://download.pytorch.org/models/text/t5.base.generation.pt",
+            checkpoint="https://download.pytorch.org/models/text/t5.base.generation.v2.pt",
             t5_config=config,
             transform=transform,
             freeze_model=True,
54 changes: 54 additions & 0 deletions test/torchtext_unittest/prototype/test_generate.py
@@ -0,0 +1,54 @@
from unittest.mock import patch

import torch
from torchtext.prototype.generate import DEFAULT_MAX_SEQ_LEN, GenerationUtil
from torchtext.prototype.models import T5_BASE_GENERATION
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase


class TestGenerationUtil(TorchtextTestCase):
    def setUp(self) -> None:
        super().setUp()
        t5_base = T5_BASE_GENERATION
        self.transform = t5_base.transform()
        self.model = t5_base.get_model()
        self.model.eval()
        # Examples taken from T5 Paper and Huggingface
        self.inputs = self.transform(
            [
                "summarize: studies have shown that owning a dog is good for you",
                "translate English to German: That is good.",
                "cola sentence: The course is jumping well.",
                "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
                "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
            ]
        )
        torch.manual_seed(0)

    def test_greedy_generate_with_t5(self) -> None:
        generation_model = GenerationUtil(self.model)

        tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30)
        generated_text = self.transform.decode(tokens.tolist())

        expected_generated_text = [
            "a dog is good for you, according to studies . owning a dog is good for you, according to studies .",
            "Das ist gut.",
            "acceptable",
            "4.0",
            "mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage",
        ]

        self.assertEqual(generated_text, expected_generated_text)

    def test_generate_errors_with_incorrect_beams(self) -> None:
        generation_model = GenerationUtil(self.model, is_encoder_decoder=True)

        with self.assertRaises(ValueError):
            generation_model.generate(self.inputs, num_beams=0)

    @patch("logging.Logger.warning")
    def test_warns_when_no_max_len_provided(self, mock) -> None:
        generation_model = GenerationUtil(self.model)
        generation_model.generate(self.inputs)
        mock.assert_called_with(f"`max_len` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")