diff --git a/notebooks/hf_vs_tt_t5.ipynb b/notebooks/hf_vs_tt_t5.ipynb new file mode 100644 index 0000000000..119cb62e69 --- /dev/null +++ b/notebooks/hf_vs_tt_t5.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Ensuring the TorchText T5 implementation matches other OSS implementations\n", + "\n", + "> In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import T5Model\n", + "from torchtext.prototype.models import T5_BASE\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "input_sentence = [\"translate to Spanish: My name is Joe\"]\n", + "output_sentence = [\"Me llamo Joe\"]\n", + "\n", + "transform = T5_BASE.transform()\n", + "tt_t5_model = T5_BASE.get_model()\n", + "\n", + "hf_t5_model = T5Model.from_pretrained(\"t5-base\")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "tokenized_sentence = transform(input_sentence)\n", + "tokenized_output = transform(output_sentence)\n", + "\n", + "tt_output = tt_t5_model(encoder_tokens=tokenized_sentence, decoder_tokens=tokenized_output)\n", + "hf_output = hf_t5_model(input_ids=tokenized_sentence, decoder_input_ids=tokenized_output, return_dict=True)\n", + "\n", + "assert torch.all(tt_output[\"encoder_output\"].eq(hf_output[\"encoder_last_hidden_state\"]))\n", + "assert torch.all(tt_output[\"decoder_output\"].eq(hf_output[\"last_hidden_state\"]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.13 ('torchtext39')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/hf_with_torchtext_gen.ipynb b/notebooks/hf_with_torchtext_gen.ipynb new file mode 100644 index 0000000000..0df74a4b39 --- /dev/null +++ b/notebooks/hf_with_torchtext_gen.ipynb @@ -0,0 +1,151 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from transformers import T5ForConditionalGeneration, T5Tokenizer, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer\n", + "from torchtext.prototype.generate import GenerationUtil" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "t5 = T5ForConditionalGeneration.from_pretrained(\"t5-base\")\n", + "bart = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large-cnn\")\n", + "gpt2 = GPT2LMHeadModel.from_pretrained(\"gpt2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", + "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", + "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n", + "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", + "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['owning a dog is good for you, according to studies. a dog is']\n" + ] + } + ], + "source": [ + "# Testing Huggingface's T5\n", + "test_sequence = [\"summarize: studies have shown that owning a dog is good for you\"]\n", + "generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)\n", + "t5_tokenizer = T5Tokenizer.from_pretrained(\"t5-base\")\n", + "test_sequence_tk = t5_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n", + "tokens = generative_hf_t5.generate(test_sequence_tk, max_len=20, pad_idx=t5.config.pad_token_id)\n", + "print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['PG. PG&E said it scheduled the blackouts in response to forecasts for high winds.']\n" + ] + } + ], + "source": [ + "# Testing Huggingface's BART\n", + "test_sequence = [\"PG&E stated it scheduled the blackouts in response to forecasts for high winds \"\n", + " \"amid dry conditions. The aim is to reduce the risk of wildfires. 
Nearly 800 thousand customers were \"\n", + " \"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.\"]\n", + "generative_hf_bart = GenerationUtil(bart, is_encoder_decoder=True, is_huggingface_model=True)\n", + "bart_tokenizer = BartTokenizer.from_pretrained(\"facebook/bart-large-cnn\")\n", + "test_sequence_tk = bart_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n", + "tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id)\n", + "print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[\"I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to\"]\n" + ] + } + ], + "source": [ + "# Testing Huggingface's GPT2\n", + "test_sequence = [\"I enjoy walking with my cute dog\"]\n", + "generative_hf_gpt2 = GenerationUtil(gpt2, is_encoder_decoder=False, is_huggingface_model=True)\n", + "gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n", + "test_sequence_tk = gpt2_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n", + "tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n", + "print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.13 ('torchtext39')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/integration_tests/prototype/test_models.py b/test/integration_tests/prototype/test_models.py index 3a76f52cd6..5f012983d5 100644 --- a/test/integration_tests/prototype/test_models.py +++ b/test/integration_tests/prototype/test_models.py @@ -66,13 +66,22 @@ def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text): model_input = transform(test_text) if model.encoder_only: - actual = model(model_input)["encoder_output"] + actual = model(encoder_tokens=model_input)["encoder_output"] + if not is_jit: + self._t5_get_encoder(model, model_input, actual) else: - actual = model(model_input)["decoder_output"] + actual = model(encoder_tokens=model_input)["decoder_output"] expected = torch.load(expected_asset_path) torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06) + def _t5_get_encoder(self, model, model_input, encoder_output): + encoder = model.get_encoder() + # Need to set the tgt_key_padding_mask to ensure the same results + encoder_padding_mask = model_input.eq(model.padding_idx) + output_from_get_encoder = encoder(tgt=model_input, tgt_key_padding_mask=encoder_padding_mask)["encoder_output"] + assert torch.all(output_from_get_encoder.eq(encoder_output)) + @nested_params(["jit", "not_jit"]) def test_t5_model(self, name) -> None: configuration, type = self.model_name.split("_") @@ -93,7 +102,8 @@ def test_t5_model(self, name) -> None: ], ) class TestT5Wrapper(TorchtextTestCase): - @parameterized.expand(["jit", "not_jit"]) + # No longer Torchscriptable + 
@parameterized.expand(["no_jit"]) def test_t5_wrapper(self, name) -> None: configuration = self.configuration test_text = ["translate English to French: I want to eat pizza for dinner."] @@ -113,7 +123,8 @@ def test_t5_wrapper(self, name) -> None: class TestT5WrapperCheckpoint(TorchtextTestCase): - @parameterized.expand(["jit", "not_jit"]) + # No longer Torchscriptable + @parameterized.expand(["no_jit"]) def test_t5_wrapper_checkpoint(self, name) -> None: test_text = ["translate English to French: I want to eat pizza for dinner."] expected_text = ["Je veux manger de la pizza pour le dîner."] @@ -127,7 +138,7 @@ def test_t5_wrapper_checkpoint(self, name) -> None: padding_idx=0, ) model = T5Wrapper( - checkpoint="https://download.pytorch.org/models/text/t5.base.generation.pt", + checkpoint="https://download.pytorch.org/models/text/t5.base.generation.v2.pt", t5_config=config, transform=transform, freeze_model=True, diff --git a/test/torchtext_unittest/prototype/test_generate.py b/test/torchtext_unittest/prototype/test_generate.py new file mode 100644 index 0000000000..7b8e0bf287 --- /dev/null +++ b/test/torchtext_unittest/prototype/test_generate.py @@ -0,0 +1,54 @@ +from unittest.mock import patch + +import torch +from torchtext.prototype.generate import DEFAULT_MAX_SEQ_LEN, GenerationUtil +from torchtext.prototype.models import T5_BASE_GENERATION +from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase + + +class TestGenerationUtil(TorchtextTestCase): + def setUp(self) -> None: + super().setUp() + t5_base = T5_BASE_GENERATION + self.transform = t5_base.transform() + self.model = t5_base.get_model() + self.model.eval() + # Examples taken from T5 Paper and Huggingface + self.inputs = self.transform( + [ + "summarize: studies have shown that owning a dog is good for you", + "translate English to German: That is good.", + "cola sentence: The course is jumping well.", + "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.", + "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...", + ] + ) + torch.manual_seed(0) + + def test_greedy_generate_with_t5(self) -> None: + generation_model = GenerationUtil(self.model) + + tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30) + generated_text = self.transform.decode(tokens.tolist()) + + expected_generated_text = [ + "a dog is good for you, according to studies . owning a dog is good for you, according to studies .", + "Das ist gut.", + "acceptable", + "4.0", + "mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage", + ] + + self.assertEqual(generated_text, expected_generated_text) + + def test_generate_errors_with_incorrect_beams(self) -> None: + generation_model = GenerationUtil(self.model, is_encoder_decoder=True) + + with self.assertRaises(ValueError): + generation_model.generate(self.inputs, num_beams=0) + + @patch("logging.Logger.warning") + def test_warns_when_no_max_len_provided(self, mock) -> None: + generation_model = GenerationUtil(self.model) + generation_model.generate(self.inputs) + mock.assert_called_with(f"`max_len` was not specified. 
Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.") diff --git a/torchtext/prototype/generate.py b/torchtext/prototype/generate.py new file mode 100644 index 0000000000..53e1e003be --- /dev/null +++ b/torchtext/prototype/generate.py @@ -0,0 +1,141 @@ +import logging +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +logger = logging.getLogger(__name__) + +DEFAULT_MAX_SEQ_LEN = 256 + + +class GenerationUtil: + """Wrapper to provide generation utils for encoder/decoder models and decoder models. + + Example: + >>> model = T5_BASE_GENERATION.get_model() + >>> generative_model = GenerationUtil(model=model) + >>> generative_model.generate(input_ids, num_beams=1, max_len=100) + + The wrapper can work with *any* model as long as it meets the following requirements: + 1. Is an encoder/decoder or decoder based model. + 2. Includes a `get_encoder` method (if applicable) and a `prepare_inputs_for_generation` method. + + This means that popular HuggingFace implementation of T5, Bart, and GPT-2 can all be used with these generation utils! + >>> from transformers import T5Model + >>> model = T5Model.from_pretrained("t5-base") + >>> generative_model = GenerationUtil(model=model, is_huggingface_model=True) + >>> generative_model.generate(input_ids, num_beams=1, max_len=100) + + More examples can be found in the `notebooks` directory of this repository. + """ + + def __init__(self, model: nn.Module, is_encoder_decoder: bool = True, is_huggingface_model: bool = False) -> None: + self.model = model + self.is_encoder_decoder = is_encoder_decoder + self.is_huggingface_model = is_huggingface_model + + def _prepare_decoder_ids_for_generation( + self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, **model_kwargs + ): + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + return model_kwargs.pop("decoder_input_ids") + else: + return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx + + def greedy_search( + self, input_ids: torch.Tensor, max_len: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs + ) -> torch.Tensor: + """Greedy search decoding for text generation. Takes the most likely next token every time. + + Inputs: + input_ids (Tensor): Text prompt(s) for greedy generation. + max_len (int): Max length to generate responses. + eos_idx (int): End of sequence index. + pad_idx (int): Padding index. + **model_kwargs + + Returns: + Batch of sequences decoded by greedy search. 
+ """ + unfinished_sequences = torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long) + + while True: + model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) + if self.is_huggingface_model: + model_inputs["return_dict"] = True + model_inputs["output_hidden_states"] = True + + # Get model output + outputs = self.model(**model_inputs) + output_key = "logits" if self.is_huggingface_model else "decoder_output" + decoder_output = outputs[output_key] + + # Calculate probabilities and take the most likely next token + probs = F.log_softmax(decoder_output[:, -1], dim=-1) + _, next_tokens = torch.topk(probs, 1) + + # For any finished sequences, padding idx should be the last token + if eos_idx is not None: + if pad_idx is not None: + next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences) + + # Append the next tokens to the previous tokens + input_ids = torch.cat([input_ids, next_tokens], dim=-1) + + if eos_idx is not None: + unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long()) + + # Stop iterating once all sequences are finished or exceed the max_len + if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_len: + break + + return input_ids + + def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_len: Optional[int]) -> torch.Tensor: + raise NotImplementedError() + + def generate( + self, + inputs: Optional[torch.Tensor] = None, + num_beams: Optional[int] = None, + max_len: Optional[int] = None, + pad_idx: int = 0, + eos_idx: int = 1, + ) -> torch.Tensor: + """Generation method. + + `num_beams` == 1 or `num_beams` is None -> greedy search + `num_beams` > 1 -> beam search + + Args: + input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation. + num_beams (int): If provided, specifies the number of beams to use in beam search generation. + max_len (int): Max length to generate responses. + pad_idx (int): Padding index. Defaults to 0. + eos_idx (int): End of sequence index. Defaults to 1. + + Returns: + Tensor of Tensors containing output sequences as ids. + + `Note`: If one beam is provided or no beams are specified, the generation method will default to greedy search. + """ + model_kwargs = {} + + if self.is_encoder_decoder: + encoder = self.model.get_encoder() + model_kwargs["encoder_outputs"] = encoder(inputs) + inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs) + + if max_len is None: + # Too hard to try to figure out the exact max_seq_length for each model + logger.warning(f"`max_len` was not specified. 
Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.") + max_len = DEFAULT_MAX_SEQ_LEN + + if num_beams == 1 or num_beams is None: + return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, **model_kwargs) + elif num_beams > 1: + return self.beam_search(inputs, num_beams, max_len) + else: + raise ValueError("`num_beams` must be >= 1.") diff --git a/torchtext/prototype/models/t5/__init__.py b/torchtext/prototype/models/t5/__init__.py index d23b5b8308..d55b1ad591 100644 --- a/torchtext/prototype/models/t5/__init__.py +++ b/torchtext/prototype/models/t5/__init__.py @@ -1,19 +1,19 @@ from .bundler import ( - T5_BASE_ENCODER, + T5_11B, + T5_11B_ENCODER, + T5_11B_GENERATION, + T5_3B, + T5_3B_ENCODER, + T5_3B_GENERATION, T5_BASE, + T5_BASE_ENCODER, T5_BASE_GENERATION, - T5_SMALL_ENCODER, - T5_SMALL, - T5_SMALL_GENERATION, - T5_LARGE_ENCODER, T5_LARGE, + T5_LARGE_ENCODER, T5_LARGE_GENERATION, - T5_3B_ENCODER, - T5_3B, - T5_3B_GENERATION, - T5_11B_ENCODER, - T5_11B, - T5_11B_GENERATION, + T5_SMALL, + T5_SMALL_ENCODER, + T5_SMALL_GENERATION, T5Bundle, ) from .model import T5Conf, T5Model diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py index 65c94dc63e..3a53f946eb 100644 --- a/torchtext/prototype/models/t5/bundler.py +++ b/torchtext/prototype/models/t5/bundler.py @@ -176,7 +176,8 @@ def build_model_from_huggingface_ckpt( t5_model_state_dict = { "token_embeddings.weight": hf_weights["shared.weight"], - "norm1.weight": hf_weights["encoder.final_layer_norm.weight"], + "encoder.token_embeddings.weight": hf_weights["shared.weight"], + "encoder.norm.weight": hf_weights["encoder.final_layer_norm.weight"], "encoder.layers.0.self_attn.relative_attention_bias.weight": hf_weights[ "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" ], @@ -210,7 +211,7 @@ def build_model_from_huggingface_ckpt( # Convert decoder layers if model is encoder-decoder if not config.encoder_only: - t5_model_state_dict["norm2.weight"] = hf_weights["decoder.final_layer_norm.weight"] + t5_model_state_dict["decoder.norm.weight"] = hf_weights["decoder.final_layer_norm.weight"] t5_model_state_dict["decoder.layers.0.self_attn.relative_attention_bias.weight"] = hf_weights[ "decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" ] @@ -331,7 +332,7 @@ def config(self) -> T5Conf: """ T5_BASE_ENCODER = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.base.encoder.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.base.encoder.v2.pt"), _config=T5Conf(encoder_only=True), transform=lambda: T5Transform( urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), @@ -344,7 +345,7 @@ def config(self) -> T5Conf: T5_BASE_ENCODER.__doc__ = ENCODER_DOC.format("BASE", "base") T5_BASE = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.base.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.base.v2.pt"), _config=T5Conf(encoder_only=False), transform=lambda: T5Transform( urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), @@ -357,7 +358,7 @@ def config(self) -> T5Conf: T5_BASE.__doc__ = MODEL_DOC.format("BASE", "base") T5_BASE_GENERATION = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.base.generation.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.base.generation.v2.pt"), _config=T5Conf(encoder_only=False, linear_head=True), transform=lambda: T5Transform( urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"), @@ -370,7 +371,7 @@ def config(self) -> T5Conf: T5_BASE_GENERATION.__doc__ = GENERATION_DOC.format("BASE", "base") T5_SMALL_ENCODER = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.small.encoder.pt"), + _path=urljoin(_TEXT_BUCKET, 
"t5.small.encoder.v2.pt"), _config=T5Conf( encoder_only=True, embedding_dim=512, @@ -391,7 +392,7 @@ def config(self) -> T5Conf: T5_SMALL = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.small.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.small.v2.pt"), _config=T5Conf( encoder_only=False, embedding_dim=512, @@ -411,7 +412,7 @@ def config(self) -> T5Conf: T5_SMALL.__doc__ = MODEL_DOC.format("SMALL", "small") T5_SMALL_GENERATION = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.small.generation.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.small.generation.v2.pt"), _config=T5Conf( encoder_only=False, linear_head=True, @@ -432,7 +433,7 @@ def config(self) -> T5Conf: T5_SMALL_GENERATION.__doc__ = GENERATION_DOC.format("SMALL", "small") T5_LARGE_ENCODER = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.large.encoder.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.large.encoder.v2.pt"), _config=T5Conf( encoder_only=True, embedding_dim=1024, @@ -452,7 +453,7 @@ def config(self) -> T5Conf: T5_LARGE_ENCODER.__doc__ = ENCODER_DOC.format("LARGE", "large") T5_LARGE = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.large.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.large.v2.pt"), _config=T5Conf( encoder_only=False, embedding_dim=1024, @@ -472,7 +473,7 @@ def config(self) -> T5Conf: T5_LARGE.__doc__ = MODEL_DOC.format("LARGE", "large") T5_LARGE_GENERATION = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.large.generation.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.large.generation.v2.pt"), _config=T5Conf( encoder_only=False, linear_head=True, @@ -493,7 +494,7 @@ def config(self) -> T5Conf: T5_LARGE_GENERATION.__doc__ = GENERATION_DOC.format("LARGE", "large") T5_3B_ENCODER = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.3b.encoder.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.3b.encoder.v2.pt"), _config=T5Conf( encoder_only=True, embedding_dim=1024, @@ -514,7 +515,7 @@ def config(self) -> T5Conf: T5_3B_ENCODER.__doc__ = ENCODER_DOC.format("3B", "3B") T5_3B = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.3b.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.3b.v2.pt"), _config=T5Conf( encoder_only=False, embedding_dim=1024, @@ -535,7 +536,7 @@ def config(self) -> T5Conf: T5_3B.__doc__ = MODEL_DOC.format("3B", "3B") T5_3B_GENERATION = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.3b.generation.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.3b.generation.v2.pt"), _config=T5Conf( encoder_only=False, linear_head=True, @@ -557,7 +558,7 @@ def config(self) -> T5Conf: T5_3B_GENERATION.__doc__ = GENERATION_DOC.format("3B", "3B") T5_11B_ENCODER = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.11b.encoder.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.11b.encoder.v2.pt"), _config=T5Conf( encoder_only=True, embedding_dim=1024, @@ -578,7 +579,7 @@ def config(self) -> T5Conf: T5_11B_ENCODER.__doc__ = ENCODER_DOC.format("11B", "11B") T5_11B = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.11b.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.11b.v2.pt"), _config=T5Conf( encoder_only=False, embedding_dim=1024, @@ -599,7 +600,7 @@ def config(self) -> T5Conf: T5_11B.__doc__ = MODEL_DOC.format("11B", "11B") T5_11B_GENERATION = T5Bundle( - _path=urljoin(_TEXT_BUCKET, "t5.11b.generation.pt"), + _path=urljoin(_TEXT_BUCKET, "t5.11b.generation.v2.pt"), _config=T5Conf( encoder_only=False, linear_head=True, diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py index 127aa44ee2..5d4c9e2991 100644 --- a/torchtext/prototype/models/t5/model.py +++ b/torchtext/prototype/models/t5/model.py @@ -1,14 +1,16 @@ +# logging library is not automatically supported by Torchscript +import warnings from 
dataclasses import dataclass -from typing import Dict, List, Optional, Union, Callable +from typing import Callable, Dict, List, Optional, Union import torch import torch.nn as nn from torch import Tensor -from .modules import T5Encoder, T5Decoder, T5LayerNorm +from .modules import T5Decoder, T5Encoder -@dataclass +@dataclass(frozen=True) class T5Conf: encoder_only: bool = False linear_head: bool = False @@ -99,12 +101,10 @@ def __init__( layer_norm_eps=config.layer_norm_eps, relative_attention_num_buckets=config.relative_attention_num_buckets, relative_attention_max_distance=config.relative_attention_max_distance, + token_embeddings=self.token_embeddings, device=device, dtype=dtype, ) - self.norm1 = T5LayerNorm(config.embedding_dim) - self.dropout1 = nn.Dropout(self.dropout) - self.dropout2 = nn.Dropout(self.dropout) if not config.encoder_only: self.decoder = T5Decoder( @@ -121,9 +121,6 @@ def __init__( device=device, dtype=dtype, ) - self.norm2 = T5LayerNorm(config.embedding_dim) - self.dropout3 = nn.Dropout(self.dropout) - self.dropout4 = nn.Dropout(self.dropout) else: self.decoder = None @@ -136,12 +133,30 @@ def __init__( for p in self.parameters(): p.requires_grad = False + def prepare_inputs_for_generation(self, input_ids, encoder_outputs): + return {"decoder_tokens": input_ids, "encoder_outputs": encoder_outputs} + + @torch.jit.ignore + def get_encoder(self) -> T5Encoder: + return self.encoder + + @torch.jit.ignore + def get_decoder(self) -> Optional[T5Decoder]: + if self.decoder is None: + warnings.warn("Decoder is not set on this model.") + return self.decoder + def forward( self, - encoder_tokens: Tensor, + encoder_tokens: Optional[Tensor] = None, decoder_tokens: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, decoder_mask: Optional[Tensor] = None, + encoder_padding_mask: Optional[Tensor] = None, + decoder_padding_mask: Optional[Tensor] = None, + encoder_outputs: Optional[ + Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]] + ] = None, ) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]: r"""Pass the inputs (and mask) through the decoder layer in turn. 
Args: @@ -167,25 +182,29 @@ def forward( encoder_sa_scores: Tuple of self-attention scores computed at each layer of the decoder encoder_ca_scores: Tuple of cross-attention scores computed at each layer of the decoder """ - encoder_padding_mask = encoder_tokens.eq(self.padding_idx) - encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens)) - encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa = self.encoder( - encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask - ) + if encoder_outputs is None: + assert encoder_tokens is not None, "If `encoder_outputs` is not specified, must provide `encoder_tokens`" - encoder_output = self.norm1(encoder_output) - encoder_output = self.dropout2(encoder_output) - encoder_hidden_states.append(encoder_output) + if encoder_padding_mask is None: + encoder_padding_mask = encoder_tokens.eq(self.padding_idx) - if not self.encoder_only: + encoder_outputs = self.encoder( + tgt=encoder_tokens, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask + ) + if not self.encoder_only: assert self.decoder is not None + assert encoder_outputs is not None + + encoder_output = encoder_outputs.get("encoder_output") + assert torch.jit.isinstance(encoder_output, Tensor) # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx. if decoder_tokens is None: + batch_size = encoder_output.size()[0] + encoder_output_device = encoder_output.device decoder_tokens = ( - torch.ones((encoder_tokens.size(0), 1), device=encoder_tokens.device, dtype=torch.long) - * self.padding_idx + torch.ones((batch_size, 1), device=encoder_output_device, dtype=torch.long) * self.padding_idx ) if decoder_mask is None: @@ -194,12 +213,13 @@ def forward( decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1) decoder_mask = decoder_mask.to(decoder_tokens.device, dtype=torch.bool) - decoder_padding_mask = decoder_tokens.eq(self.padding_idx) - # T5 implemention uses padding idx to start sequence. Want to ignore this when masking - decoder_padding_mask[:, 0] = False + if decoder_padding_mask is None: + decoder_padding_mask = decoder_tokens.eq(self.padding_idx) + # T5 implemention uses padding idx to start sequence. 
Want to ignore this when masking + decoder_padding_mask[:, 0] = False - decoder_embeddings = self.dropout3(self.token_embeddings(decoder_tokens)) - decoder_output, decoder_hidden_states, decoder_position_bias, decoder_sa, decoder_ca = self.decoder( + decoder_embeddings = self.token_embeddings(decoder_tokens) + decoder_outputs = self.decoder( decoder_embeddings, memory=encoder_output, tgt_mask=decoder_mask, @@ -208,9 +228,8 @@ def forward( memory_key_padding_mask=encoder_padding_mask, ) - decoder_output = self.norm2(decoder_output) - decoder_output = self.dropout4(decoder_output) - decoder_hidden_states.append(decoder_output) + decoder_output = decoder_outputs.get("decoder_output") + assert torch.jit.isinstance(decoder_output, Tensor) if self.linear_head: assert self.lm_head is not None @@ -219,28 +238,16 @@ def forward( # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661 decoder_output = decoder_output * (self.embedding_dim ** -0.5) decoder_output = self.lm_head(decoder_output) + decoder_outputs["decoder_output"] = decoder_output - t5_output = { - "encoder_output": encoder_output, - "encoder_hidden_states": encoder_hidden_states, - "encoder_position_bias": encoder_position_bias, - "encoder_sa_scores": encoder_sa, - "decoder_output": decoder_output, - "decoder_hidden_states": decoder_hidden_states, - "decoder_position_bias": decoder_position_bias, - "decoder_sa_scores": decoder_sa, - "decoder_ca_scores": decoder_ca, - } - else: - t5_output = { - "encoder_output": encoder_output, - "encoder_hidden_states": encoder_hidden_states, - "encoder_position_bias": encoder_position_bias, - "encoder_sa_scores": encoder_sa, - } + encoder_outputs.update(decoder_outputs) + encoder_decoder_outputs = encoder_outputs assert torch.jit.isinstance( - t5_output, Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]] + encoder_decoder_outputs, + Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]], ) - return t5_output + return encoder_decoder_outputs + + return encoder_outputs diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py index 7053f42cf5..56d72b3771 100644 --- a/torchtext/prototype/models/t5/modules.py +++ b/torchtext/prototype/models/t5/modules.py @@ -15,7 +15,7 @@ import math import warnings -from typing import List, Optional, Tuple, Union, Callable +from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -773,7 +773,8 @@ def _ca_block( # NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L835 class T5Encoder(nn.Module): - r"""T5Encoder is a stack of N encoder layers + """T5Encoder is a stack of N encoder layers. + Args: d_model: Number of expected features in the input (required). nhead: Number of heads in the multihead attention models (required). @@ -787,7 +788,10 @@ class T5Encoder(nn.Module): relative_attention_num_buckets: Number of relative position buckets (default: 32) relative_attention_max_distance: Maximum threshold on the relative distance used to allocate buckets. Anything larger gets placed in the same bucket (defulat: 128) - Examples:: + token_embeddings (nn.Module): Embedding layer to be passed in the case that the input to `forward` + is not already embedded. 
+ + Examples: >>> encoder = T5Encoder(d_model=768, nhead=12, num_layers=12) >>> tgt = torch.rand(32, 10, 512) >>> out = encoder(tgt) @@ -805,11 +809,12 @@ def __init__( layer_norm_eps: float = 1e-6, relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, + token_embeddings: Optional[nn.Module] = None, device: Optional[torch.device] = None, dtype=None, ) -> None: super().__init__() - + self.token_embeddings = token_embeddings self.layers = nn.ModuleList( [ T5EncoderLayer( @@ -830,24 +835,44 @@ def __init__( ] ) self.num_layers = num_layers + self.norm = T5LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) def forward( self, - tgt: Tensor, + tgt: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]: + embedded_tgt: Optional[Tensor] = None, + ) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]: r"""Pass the input (and masks) through the stack of encoder layers. + Args: - tgt: Input sequence to the encoder layer. (required). - Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence - length, and E is the model dimension. - tgt_mask: Attention mask for self-attention. (optional). + tgt (Optional[Tensor]): Tokenized input sequence to the encoder. + Must be batch first with shape (B, Ne) where B is the batch size and Ne is the + encoder input sequence length. + tgt_mask (Optional[Tensor]): Attention mask for self-attention. Must have shape (Nt, Nt). - tgt_key_padding_mask: Mask for the tgt keys per batch (optional). + tgt_key_padding_mask (Optional[Tensor]): Mask for the tgt keys per batch. Must have shape (B, Nt). + embedded_tgt (Optional[Tensor]): Embedded input sequence to the encoder layer. + Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence + length, and E is the model dimension. + *Note*: If you do not provide this `embedded_tgt`, you must have provided a `token_embedding` layer \ + in the initialization of the T5Encoder. + + Returns: + Tuple of last hidden layer, all hidden layers, position bias, and self-attention scores """ - output = tgt + # This keeps the encoder self-contained and easy to use individually + if embedded_tgt is None: + assert ( + self.token_embeddings is not None and tgt is not None + ), "Must provide `token_embeddings` and `tgt` if not providing already embedded tokens." 
+ embedded_tgt = self.token_embeddings(tgt) + + output = self.dropout1(embedded_tgt) position_bias = None all_outputs = torch.jit.annotate(List[Tensor], []) all_sa_scores = torch.jit.annotate(List[Optional[Tensor]], []) @@ -861,7 +886,17 @@ def forward( ) all_sa_scores.append(sa_score) - return output, all_outputs, position_bias, all_sa_scores + output = self.norm(output) + output = self.dropout2(output) + + all_outputs.append(output) + + return { + "encoder_output": output, + "encoder_hidden_states": all_outputs, + "encoder_position_bias": position_bias, + "encoder_sa_scores": all_sa_scores, + } # NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L835 @@ -923,20 +958,23 @@ def __init__( for i in range(num_layers) ] ) + self.norm = T5LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) self.num_layers = num_layers def forward( self, - tgt: Tensor, + embedded_tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]], List[Optional[Tensor]]]: + ) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]], List[Optional[Tensor]]]]: r"""Pass the inputs (and masks) through the stack of decoder layers. Args: - tgt: Input sequence to the decoder layer. (required). + embedded_tgt: Input sequence to the decoder layer. (required). Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence length, and E is the model dimension. memory: Sequence from the last layer of the encoder. (required). @@ -951,7 +989,8 @@ def forward( memory_key_padding_mask: Mask for the memory keys per batch (optional). Must have shape (B, Ns). """ - output = tgt + + output = self.dropout1(embedded_tgt) position_bias = None all_outputs = torch.jit.annotate(List[Tensor], []) all_sa_scores = torch.jit.annotate(List[Optional[Tensor]], []) @@ -970,4 +1009,15 @@ def forward( all_sa_scores.append(sa_score) all_ca_scores.append(ca_score) - return output, all_outputs, position_bias, all_sa_scores, all_ca_scores + output = self.norm(output) + output = self.dropout2(output) + + all_outputs.append(output) + + return { + "decoder_output": output, + "decoder_hidden_states": all_outputs, + "decoder_position_bias": position_bias, + "decoder_sa_scores": all_sa_scores, + "decoder_ca_scores": all_ca_scores, + } diff --git a/torchtext/prototype/models/t5/wrapper.py b/torchtext/prototype/models/t5/wrapper.py index 8fa38a065f..6027bf72e9 100644 --- a/torchtext/prototype/models/t5/wrapper.py +++ b/torchtext/prototype/models/t5/wrapper.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -5,14 +6,14 @@ import torch.nn.functional as F from torch import Tensor from torchtext.prototype.models import ( + T5_11B_GENERATION, + T5_3B_GENERATION, T5_BASE_GENERATION, - T5_SMALL_GENERATION, T5_LARGE_GENERATION, - T5_3B_GENERATION, - T5_11B_GENERATION, + T5_SMALL_GENERATION, + T5Bundle, T5Conf, T5Transform, - T5Bundle, ) @@ -47,6 +48,8 @@ def __init__( strict (bool): Passed to :func: `torch.nn.Module.load_state_dict` method. (Default: `False`) dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. 
(Default: `None`) """ + warnings.warn("`T5Wrapper` is being deprecated. Please use new `GenerationUtils`.", category=DeprecationWarning) + super().__init__() if configuration is None: @@ -135,15 +138,11 @@ def beam_search( return new_decoder_tokens, new_scores, new_incomplete_sentences def generate(self, encoder_tokens: Tensor, beam_size: int, eos_idx: int = 1, max_seq_len: int = 512) -> Tensor: - # pass tokens through encoder bsz = encoder_tokens.size(0) + encoder = self.model.get_encoder() encoder_padding_mask = encoder_tokens.eq(self.model.padding_idx) - encoder_embeddings = self.model.dropout1(self.model.token_embeddings(encoder_tokens)) - encoder_output = self.model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0] - - encoder_output = self.model.norm1(encoder_output) - encoder_output = self.model.dropout2(encoder_output) + encoder_outputs = encoder(tgt=encoder_tokens, tgt_key_padding_mask=encoder_padding_mask) # initialize decoder input sequence; T5 uses padding index as starter index to decoder sequence decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * self.model.padding_idx @@ -154,37 +153,21 @@ def generate(self, encoder_tokens: Tensor, beam_size: int, eos_idx: int = 1, max # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token for step in range(max_seq_len): - if step == 1: # duplicate and order encoder output so that each beam is treated as its own independent sequence + encoder_output = encoder_outputs.get("encoder_output") new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) new_order = new_order.to(encoder_tokens.device).long() encoder_output = encoder_output.index_select(0, new_order) + encoder_outputs["encoder_output"] = encoder_output encoder_padding_mask = encoder_padding_mask.index_select(0, new_order) - # causal mask and padding mask for decoder sequence - tgt_len = decoder_tokens.shape[1] - decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1) - decoder_mask = decoder_mask.to(torch.bool) - decoder_padding_mask = decoder_tokens.eq(self.model.padding_idx) - - # T5 implemention uses padding idx to start sequence. 
Want to ignore this when masking - decoder_padding_mask[:, 0] = False - # pass decoder sequence through decoder - decoder_embeddings = self.model.dropout3(self.model.token_embeddings(decoder_tokens)) - decoder_output = self.model.decoder( - decoder_embeddings, - memory=encoder_output, - tgt_mask=decoder_mask, - tgt_key_padding_mask=decoder_padding_mask, - memory_key_padding_mask=encoder_padding_mask, - )[0] - - decoder_output = self.model.norm2(decoder_output) - decoder_output = self.model.dropout4(decoder_output) - decoder_output = decoder_output * (self.model.embedding_dim ** -0.5) - decoder_output = self.model.lm_head(decoder_output) + decoder_output = self.model( + decoder_tokens=decoder_tokens, + encoder_padding_mask=encoder_padding_mask, + encoder_outputs=encoder_outputs, + ).get("decoder_output") decoder_tokens, scores, incomplete_sentences = self.beam_search( beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences @@ -203,7 +186,6 @@ def generate(self, encoder_tokens: Tensor, beam_size: int, eos_idx: int = 1, max return decoder_tokens def forward(self, input_text: List[str], beam_size: int, max_seq_len: int) -> Union[List[str], str]: - model_input = self.transform(input_text) model_output_tensor = self.generate(encoder_tokens=model_input, beam_size=beam_size, max_seq_len=max_seq_len) model_output_list = torch.jit.annotate(List[List[int]], model_output_tensor.tolist())
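
The changes above are easiest to digest through a few short usage sketches. First, the parity check from `notebooks/hf_vs_tt_t5.ipynb` as a plain script (assumes `pip install transformers`; the `eval()` calls and the `torch.allclose` tolerance are additions here for robustness, not part of the notebook, which asserts exact equality):

```python
import torch
from transformers import T5Model
from torchtext.prototype.models import T5_BASE

input_sentence = ["translate to Spanish: My name is Joe"]
output_sentence = ["Me llamo Joe"]

transform = T5_BASE.transform()
tt_t5_model = T5_BASE.get_model()
tt_t5_model.eval()  # disable dropout so the comparison is deterministic

hf_t5_model = T5Model.from_pretrained("t5-base")
hf_t5_model.eval()

tokenized_input = transform(input_sentence)
tokenized_output = transform(output_sentence)

tt_output = tt_t5_model(encoder_tokens=tokenized_input, decoder_tokens=tokenized_output)
hf_output = hf_t5_model(input_ids=tokenized_input, decoder_input_ids=tokenized_output, return_dict=True)

# Compare the final encoder and decoder hidden states from the two implementations.
assert torch.allclose(tt_output["encoder_output"], hf_output["encoder_last_hidden_state"])
assert torch.allclose(tt_output["decoder_output"], hf_output["last_hidden_state"])
```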
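
The encoder/decoder refactor moves the shared token embedding, final `T5LayerNorm`, and dropouts into `T5Encoder`/`T5Decoder` and has them return dictionaries, so the encoder can be pulled out with `get_encoder()` and its output reused by `T5Model.forward` via `encoder_outputs`. This mirrors `_t5_get_encoder` in the integration tests and the updated `T5Wrapper.generate`; the prompt below is illustrative:

```python
from torchtext.prototype.models import T5_BASE

transform = T5_BASE.transform()
model = T5_BASE.get_model()
model.eval()

tokens = transform(["translate English to German: That is good."])
padding_mask = tokens.eq(model.padding_idx)

# Run the encoder on its own; it now owns the token embedding, final layer
# norm, and dropout, and returns a dictionary of outputs.
encoder = model.get_encoder()
encoder_outputs = encoder(tgt=tokens, tgt_key_padding_mask=padding_mask)
print(encoder_outputs["encoder_output"].shape)  # (batch, seq_len, embedding_dim)

# Reuse the cached encoder outputs in the full forward pass. With
# decoder_tokens omitted, decoding starts from the padding index.
out = model(encoder_outputs=encoder_outputs, encoder_padding_mask=padding_mask)
print(out["decoder_output"].shape)
```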
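
`GenerationUtil` currently implements greedy search only (`num_beams=1` or unspecified; `num_beams > 1` routes to the not-yet-implemented `beam_search`). A minimal end-to-end sketch with the torchtext bundle, following `test_greedy_generate_with_t5` (prompt and `max_len` are illustrative):

```python
from torchtext.prototype.generate import GenerationUtil
from torchtext.prototype.models import T5_BASE_GENERATION

transform = T5_BASE_GENERATION.transform()
model = T5_BASE_GENERATION.get_model()
model.eval()

generator = GenerationUtil(model)  # is_encoder_decoder=True by default

inputs = transform(["summarize: studies have shown that owning a dog is good for you"])
tokens = generator.generate(inputs, num_beams=1, max_len=30)

# Decode the generated ids back to text with the same transform.
print(transform.decode(tokens.tolist()))
```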
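
Because the wrapper only needs a `get_encoder()` method and a `prepare_inputs_for_generation()` method, the same utility drives Hugging Face models when `is_huggingface_model=True`, as in `notebooks/hf_with_torchtext_gen.ipynb` (requires `pip install transformers`; prompt and `max_len` are illustrative):

```python
from torchtext.prototype.generate import GenerationUtil
from transformers import T5ForConditionalGeneration, T5Tokenizer

t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")

generator = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)

test_sequence = ["summarize: studies have shown that owning a dog is good for you"]
input_ids = t5_tokenizer(test_sequence, return_tensors="pt").input_ids

tokens = generator.generate(input_ids, max_len=20, pad_idx=t5.config.pad_token_id)
print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))
```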