Beginning of generation utils and necessary refactors of T5 Model #2011
Merged
Changes from all commits (6 commits):
c008115 - Separate encoding/decoding logic for T5 model in preparation for gene… (joecummings)
b699de2 - Add generation utils with greedy search and tests (joecummings)
6860a30 - Fix linting issues (joecummings)
63a3020 - Update docstring (joecummings)
d10043d - Update generate test to fix warning test failure (joecummings)
25f18ef - Add default max seq len to logging and test (joecummings)
@@ -0,0 +1,83 @@ (new notebook: parity check between the TorchText and HuggingFace T5 implementations)
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ensuring the TorchText T5 implementation matches other OSS implementations\n",
    "\n",
    "> In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Model\n",
    "from torchtext.prototype.models import T5_BASE\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_sentence = [\"translate to Spanish: My name is Joe\"]\n",
    "output_sentence = [\"Me llamo Joe\"]\n",
    "\n",
    "transform = T5_BASE.transform()\n",
    "tt_t5_model = T5_BASE.get_model()\n",
    "\n",
    "hf_t5_model = T5Model.from_pretrained(\"t5-base\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_sentence = transform(input_sentence)\n",
    "tokenized_output = transform(output_sentence)\n",
    "\n",
    "tt_output = tt_t5_model(encoder_tokens=tokenized_sentence, decoder_tokens=tokenized_output)\n",
    "hf_output = hf_t5_model(input_ids=tokenized_sentence, decoder_input_ids=tokenized_output, return_dict=True)\n",
    "\n",
    "assert torch.all(tt_output[\"encoder_output\"].eq(hf_output[\"encoder_last_hidden_state\"]))\n",
    "assert torch.all(tt_output[\"decoder_output\"].eq(hf_output[\"last_hidden_state\"]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('torchtext39')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
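For quick local verification, the same parity check can be run as a standalone script. The sketch below mirrors the notebook's cells; the eval()/no_grad() calls and the torch.allclose tolerance are additions here (not in the notebook), with allclose being a more forgiving variant of the exact .eq comparison:

# Minimal standalone sketch of the notebook's parity check.
# Assumes `pip install transformers` and a torchtext build exposing T5_BASE.
import torch
from transformers import T5Model
from torchtext.prototype.models import T5_BASE

transform = T5_BASE.transform()
tt_t5_model = T5_BASE.get_model().eval()
hf_t5_model = T5Model.from_pretrained("t5-base").eval()

tokens = transform(["translate to Spanish: My name is Joe"])
target = transform(["Me llamo Joe"])

with torch.no_grad():
    tt_out = tt_t5_model(encoder_tokens=tokens, decoder_tokens=target)
    hf_out = hf_t5_model(input_ids=tokens, decoder_input_ids=target, return_dict=True)

# allclose tolerates tiny floating-point differences that an exact .eq would flag.
assert torch.allclose(tt_out["encoder_output"], hf_out["encoder_last_hidden_state"], atol=1e-5)
assert torch.allclose(tt_out["decoder_output"], hf_out["last_hidden_state"], atol=1e-5)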
@@ -0,0 +1,151 @@ (new notebook: running GenerationUtil against HuggingFace T5, BART, and GPT-2)
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import T5ForConditionalGeneration, T5Tokenizer, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer\n",
    "from torchtext.prototype.generate import GenerationUtil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "t5 = T5ForConditionalGeneration.from_pretrained(\"t5-base\")\n",
    "bart = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large-cnn\")\n",
    "gpt2 = GPT2LMHeadModel.from_pretrained(\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
      "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
      "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
      "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
      "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
      " warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['owning a dog is good for you, according to studies. a dog is']\n"
     ]
    }
   ],
   "source": [
    "# Testing Huggingface's T5\n",
    "test_sequence = [\"summarize: studies have shown that owning a dog is good for you\"]\n",
    "generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)\n",
    "t5_tokenizer = T5Tokenizer.from_pretrained(\"t5-base\")\n",
    "test_sequence_tk = t5_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
    "tokens = generative_hf_t5.generate(test_sequence_tk, max_len=20, pad_idx=t5.config.pad_token_id)\n",
    "print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['PG. PG&E said it scheduled the blackouts in response to forecasts for high winds.']\n"
     ]
    }
   ],
   "source": [
    "# Testing Huggingface's BART\n",
    "test_sequence = [\"PG&E stated it scheduled the blackouts in response to forecasts for high winds \"\n",
    " \"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were \"\n",
    " \"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.\"]\n",
    "generative_hf_bart = GenerationUtil(bart, is_encoder_decoder=True, is_huggingface_model=True)\n",
    "bart_tokenizer = BartTokenizer.from_pretrained(\"facebook/bart-large-cnn\")\n",
    "test_sequence_tk = bart_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
    "tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id)\n",
    "print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[\"I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to\"]\n"
     ]
    }
   ],
   "source": [
    "# Testing Huggingface's GPT2\n",
    "test_sequence = [\"I enjoy walking with my cute dog\"]\n",
    "generative_hf_gpt2 = GenerationUtil(gpt2, is_encoder_decoder=False, is_huggingface_model=True)\n",
    "gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
    "test_sequence_tk = gpt2_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
    "tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n",
    "print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('torchtext39')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
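The notebook above wraps HuggingFace models; the same utility also drives torchtext's native T5 generation bundle. Below is a minimal sketch of that usage, mirroring the unit test further down (it assumes the T5_BASE_GENERATION bundle and the greedy search added in this PR; exact outputs are not guaranteed to match the HuggingFace runs):

# Minimal sketch: greedy generation with torchtext's own T5 bundle.
# Mirrors the usage exercised in the generation unit test below.
from torchtext.prototype.generate import GenerationUtil
from torchtext.prototype.models import T5_BASE_GENERATION

transform = T5_BASE_GENERATION.transform()
model = T5_BASE_GENERATION.get_model()
model.eval()

inputs = transform(["summarize: studies have shown that owning a dog is good for you"])

# The unit tests construct GenerationUtil(model) with no extra flags for the
# torchtext T5, so the HuggingFace-specific arguments are only needed for HF models.
generator = GenerationUtil(model)
tokens = generator.generate(inputs, num_beams=1, max_len=30)  # num_beams=1 -> greedy search
print(transform.decode(tokens.tolist()))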
@@ -0,0 +1,54 @@ (new test file for the generation utils)
from unittest.mock import patch

import torch
from torchtext.prototype.generate import DEFAULT_MAX_SEQ_LEN, GenerationUtil
from torchtext.prototype.models import T5_BASE_GENERATION
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase


class TestGenerationUtil(TorchtextTestCase):
    def setUp(self) -> None:
        super().setUp()
        t5_base = T5_BASE_GENERATION
        self.transform = t5_base.transform()
        self.model = t5_base.get_model()
        self.model.eval()
        # Examples taken from T5 Paper and Huggingface
        self.inputs = self.transform(
            [
                "summarize: studies have shown that owning a dog is good for you",
                "translate English to German: That is good.",
                "cola sentence: The course is jumping well.",
                "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
                "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
            ]
        )
        torch.manual_seed(0)

    def test_greedy_generate_with_t5(self) -> None:
        generation_model = GenerationUtil(self.model)

        tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30)
        generated_text = self.transform.decode(tokens.tolist())

        expected_generated_text = [
            "a dog is good for you, according to studies . owning a dog is good for you, according to studies .",
            "Das ist gut.",
            "acceptable",
            "4.0",
            "mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage",
        ]

        self.assertEqual(generated_text, expected_generated_text)

    def test_generate_errors_with_incorrect_beams(self) -> None:
        generation_model = GenerationUtil(self.model, is_encoder_decoder=True)

        with self.assertRaises(ValueError):
            generation_model.generate(self.inputs, num_beams=0)

    @patch("logging.Logger.warning")
    def test_warns_when_no_max_len_provided(self, mock) -> None:
        generation_model = GenerationUtil(self.model)
        generation_model.generate(self.inputs)
        mock.assert_called_with(f"`max_len` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
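For context on what the greedy search exercised above does conceptually, here is an illustrative, self-contained sketch of greedy decoding over a generic step function. It is not the PR's implementation, and step_fn is a hypothetical stand-in for whatever produces next-token logits (the real GenerationUtil drives the wrapped encoder-decoder or decoder-only model instead):

import torch


def greedy_decode(step_fn, start_tokens: torch.Tensor, max_len: int, eos_idx: int) -> torch.Tensor:
    """Illustrative greedy search: repeatedly append the argmax next token.

    step_fn maps the tokens generated so far, shaped (batch, seq), to
    next-token logits shaped (batch, vocab).
    """
    tokens = start_tokens
    finished = torch.zeros(tokens.size(0), dtype=torch.bool)
    for _ in range(max_len):
        logits = step_fn(tokens)                           # (batch, vocab)
        next_token = logits.argmax(dim=-1, keepdim=True)   # greedy choice
        tokens = torch.cat([tokens, next_token], dim=-1)
        finished |= next_token.squeeze(-1).eq(eos_idx)
        if finished.all():                                 # stop once every sequence hit EOS
            break
    return tokens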
I know we often verify completeness w/ internal notebooks - I thought that for the ones showing parity with HuggingFace or other external libraries, we could put the notebooks in the actual repo. Seems like a better way to keep track than Bento notebooks w/ scattered ownership.
Can you upload the notebook to Github gist and provide a link in the PR so it's easier to review the contents?
Sure, but as a quick fix you can click the expand dots at the top right of this file and select "View file", which will give you a rendered notebook view.