diff --git a/test/asset/t5.base.generation.output.pt b/test/asset/t5.base.generation.output.pt
new file mode 100644
index 0000000000..e6aa0dfaf3
Binary files /dev/null and b/test/asset/t5.base.generation.output.pt differ
diff --git a/test/prototype/integration_tests/test_models.py b/test/prototype/integration_tests/test_models.py
index 27fc313acc..f955f6612f 100644
--- a/test/prototype/integration_tests/test_models.py
+++ b/test/prototype/integration_tests/test_models.py
@@ -4,6 +4,7 @@
 from torchtext.prototype.models import (
     T5_BASE_ENCODER,
     T5_BASE,
+    T5_BASE_GENERATION,
 )
 
 
@@ -24,7 +25,7 @@ def _t5_model(self, t5_model, expected_asset_name, test_text):
         actual = model(model_input)["decoder_output"]
 
         expected = torch.load(expected_asset_path)
-        torch.testing.assert_close(actual, expected)
+        torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)
 
     def test_t5_base_encoder_model(self):
         expected_asset_name = "t5.base.encoder.output.pt"
@@ -35,3 +36,8 @@ def test_t5_base_model(self):
         expected_asset_name = "t5.base.output.pt"
         test_text = ["Hello world", "Attention rocks!"]
         self._t5_model(t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text)
+
+    def test_t5_base_generation_model(self):
+        expected_asset_name = "t5.base.generation.output.pt"
+        test_text = ["Hello world", "Attention rocks!"]
+        self._t5_model(t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text)
diff --git a/test/prototype/models/test_models.py b/test/prototype/models/test_models.py
index ccd165230c..17d438088f 100644
--- a/test/prototype/models/test_models.py
+++ b/test/prototype/models/test_models.py
@@ -1,13 +1,16 @@
+import copy
 from unittest.mock import patch
 
+import torch
 from test.common.torchtext_test_case import TorchtextTestCase
+from torch.nn import functional as F
 
 
 class TestModels(TorchtextTestCase):
     def test_t5_bundler_build_model(self):
         from torchtext.prototype.models import T5Conf, T5Model, T5Bundle
 
-        # case: user provide encoder checkpoint state dict
+        # case: user provides encoder checkpoint state dict
         dummy_encoder_conf = T5Conf(
             encoder_only=True,
             vocab_size=10,
@@ -15,12 +18,13 @@ def test_t5_bundler_build_model(self):
             ffn_dimension=64,
             num_attention_heads=2,
             num_encoder_layers=2,
+            num_decoder_layers=2,
         )
         dummy_t5_encoder = T5Model(dummy_encoder_conf)
         t5_encoder_model = T5Bundle.build_model(config=dummy_encoder_conf, checkpoint=dummy_t5_encoder.state_dict())
         self.assertEqual(t5_encoder_model.state_dict(), dummy_t5_encoder.state_dict())
 
-        # case: user provide encoder-decoder checkpoint state dict
+        # case: user provides encoder-decoder checkpoint state dict
         dummy_t5_conf = T5Conf(
             encoder_only=False,
             vocab_size=10,
@@ -28,41 +32,46 @@ def test_t5_bundler_build_model(self):
             ffn_dimension=64,
             num_attention_heads=2,
             num_encoder_layers=2,
+            num_decoder_layers=2,
         )
         dummy_t5 = T5Model(dummy_t5_conf)
         t5_model = T5Bundle.build_model(config=dummy_t5_conf, checkpoint=dummy_t5.state_dict())
         self.assertEqual(t5_model.state_dict(), dummy_t5.state_dict())
 
-    @patch("logging.Logger.warning")
-    def test_t5_bundler_get_model(self, mock):
-        from torchtext.prototype.models import T5Conf, T5Bundle
-
-        # encoder-only
-        dummy_encoder_conf = T5Conf(
-            encoder_only=True,
+        # case: user provides checkpoint state dict for encoder-decoder with generation
+        dummy_t5_generation_conf = T5Conf(
+            encoder_only=False,
+            linear_head=True,
             vocab_size=10,
             embedding_dim=16,
             ffn_dimension=64,
             num_attention_heads=2,
             num_encoder_layers=2,
+            num_decoder_layers=2,
         )
-        encoder_bundle = T5Bundle(dummy_encoder_conf)
-        encoder_bundle.get_model(load_weights=False, freeze_model=True)
-        mock.assert_called_with(
-            "The model is not loaded with pre-trained weights. Setting freeze_model to True will hinder model from learning appropriate weights."
+        dummy_t5_generation = T5Model(dummy_t5_generation_conf)
+        t5_generation_model = T5Bundle.build_model(
+            config=dummy_t5_generation_conf, checkpoint=dummy_t5_generation.state_dict()
         )
+        self.assertEqual(t5_generation_model.state_dict(), dummy_t5_generation.state_dict())
 
-        # encoder-decoder
-        dummy_t5_conf = T5Conf(
+    @patch("logging.Logger.warning")
+    def test_t5_bundler_get_model(self, mock):
+        from torchtext.prototype.models import T5Conf, T5Bundle
+
+        # encoder-decoder with generation
+        dummy_t5_generation_conf = T5Conf(
             encoder_only=False,
+            linear_head=True,
             vocab_size=10,
             embedding_dim=16,
             ffn_dimension=64,
             num_attention_heads=2,
             num_encoder_layers=2,
+            num_decoder_layers=2,
         )
-        t5_bundle = T5Bundle(dummy_t5_conf)
-        t5_bundle.get_model(load_weights=False, freeze_model=True)
+        t5_generation_bundle = T5Bundle(dummy_t5_generation_conf)
+        t5_generation_bundle.get_model(load_weights=False, freeze_model=True)
         mock.assert_called_with(
             "The model is not loaded with pre-trained weights. Setting freeze_model to True will hinder model from learning appropriate weights."
         )
@@ -79,6 +88,7 @@ def test_t5_bundler_raise_checkpoint(self):
                 ffn_dimension=64,
                 num_attention_heads=2,
                 num_encoder_layers=2,
+                num_decoder_layers=2,
             )
             T5Bundle.build_model(
                 config=dummy_encoder_conf,
@@ -95,6 +105,7 @@ def test_t5_bundler_raise_checkpoint(self):
                 ffn_dimension=64,
                 num_attention_heads=2,
                 num_encoder_layers=2,
+                num_decoder_layers=2,
             )
             T5Bundle.build_model(
                 config=dummy_t5_conf,
@@ -102,6 +113,24 @@ def test_t5_bundler_raise_checkpoint(self):
                 checkpoint=1,
             )
 
+        # encoder-decoder with generation
+        with self.assertRaises(TypeError):
+            dummy_t5_generation_conf = T5Conf(
+                encoder_only=False,
+                linear_head=True,
+                vocab_size=10,
+                embedding_dim=16,
+                ffn_dimension=64,
+                num_attention_heads=2,
+                num_encoder_layers=2,
+                num_decoder_layers=2,
+            )
+            T5Bundle.build_model(
+                config=dummy_t5_generation_conf,
+                freeze_model=True,
+                checkpoint=1,
+            )
+
     def test_t5_bundler_conf_property(self):
         from torchtext.prototype.models import T5Conf, T5Bundle
 
@@ -112,6 +141,43 @@ def test_t5_bundler_conf_property(self):
             ffn_dimension=64,
             num_attention_heads=2,
             num_encoder_layers=2,
+            num_decoder_layers=2,
         )
         t5_bundle = T5Bundle(dummy_t5_conf)
         self.assertTrue(isinstance(t5_bundle.config, T5Conf))
+
+    def test_t5_bundler_train(self):
+        from torch.optim import SGD
+        from torchtext.prototype.models import T5Conf, T5Model, T5Bundle
+
+        def _train(model):
+            optim = SGD(model.parameters(), lr=1)
+            model_input = torch.tensor([[1, 2, 3, 4, 5]])
+            target = torch.tensor([1])
+            output = model(model_input)["decoder_output"]
+            logits = F.log_softmax(output[:, -1], dim=-1)
+            loss = F.cross_entropy(logits, target)
+            loss.backward()
+            optim.step()
+
+        dummy_conf = T5Conf(
+            encoder_only=False,
+            linear_head=True,
+            vocab_size=10,
+            embedding_dim=16,
+            ffn_dimension=64,
+            num_attention_heads=2,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            training=True,
+        )
+        dummy_model = T5Model(dummy_conf)
+        model = T5Bundle.build_model(
+            config=dummy_conf,
+            freeze_model=False,
+            checkpoint=dummy_model.state_dict(),
+        )
+        current_state_dict = copy.deepcopy(model.state_dict())
+
+        _train(model)
+        self.assertNotEqual(model.state_dict(), current_state_dict)
diff --git a/torchtext/prototype/models/t5/__init__.py b/torchtext/prototype/models/t5/__init__.py
index 45c0de5e04..69bb75aef6 100644
--- a/torchtext/prototype/models/t5/__init__.py
+++ b/torchtext/prototype/models/t5/__init__.py
@@ -1,6 +1,7 @@
 from .bundler import (
     T5_BASE_ENCODER,
     T5_BASE,
+    T5_BASE_GENERATION,
     T5Bundle,
 )
 from .model import T5Conf, T5Model
@@ -12,5 +13,6 @@
     "T5Bundle",
     "T5_BASE_ENCODER",
     "T5_BASE",
+    "T5_BASE_GENERATION",
     "T5Transform",
 ]
diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py
index 0a52007aae..c4c1d78fb6 100644
--- a/torchtext/prototype/models/t5/bundler.py
+++ b/torchtext/prototype/models/t5/bundler.py
@@ -39,6 +39,19 @@ class T5Bundle:
         >>> output.shape
         torch.Size([2, 1, 768])
 
+    Example - Pretrained base t5 model for generation
+        >>> import torch, torchtext
+        >>> import torch.nn.functional as F
+        >>> t5_base_generation = torchtext.prototype.models.T5_BASE_GENERATION
+        >>> transform = t5_base_generation.transform()
+        >>> input_seq = ["Hello world", "Attention rocks!"]
+        >>> model = t5_base_generation.get_model()
+        >>> model_input = transform(input_seq)
+        >>> output = model(model_input)['decoder_output']
+        >>> logits = F.log_softmax(output[:,-1], dim=-1)
+        >>> logits.shape
+        torch.Size([2, 32128])
+
     Example - User-specified configuration and checkpoint
         >>> from torchtext.prototype.models import T5Conf, T5Bundle
         >>> model_weights_path = "https://download.pytorch.org/models/text/t5.base.encoder.pt"
@@ -137,7 +150,8 @@ def config(self) -> T5Conf:
 )
 
 T5_BASE_ENCODER.__doc__ = """
-    T5 Encoder with Base configuration
+    T5_BASE_ENCODER is an encoder-only model from a pre-trained T5 model with the base configuration.
+    It returns the normalized output from the final layer of the encoder.
 
     The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
     <https://arxiv.org/pdf/1910.10683.pdf>`. It introduces a unified framework that converts text-based
@@ -167,7 +181,39 @@ def config(self) -> T5Conf:
 )
 
 T5_BASE.__doc__ = """
-    T5 Encoder-Decoder with Base configuration
+    T5_BASE is an encoder-decoder model from a pre-trained T5 model with the base configuration.
+    It returns the normalized output from the final layer of the decoder.
+
+    The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
+    <https://arxiv.org/pdf/1910.10683.pdf>`. It introduces a unified framework that converts text-based
+    language problems, such as translation, question-answering, and summarization, into a text-to-text format. The
+    Colossal Clean Crawled Corpus (C4) dataset is used to pre-train the model on a masked language modeling task,
+    and various datasets are used to fine-tune the model on each downstream task. The model's architecture is a modified version
+    of the canonical Transformer architecture.
+
+    Originally published by the authors of T5 under Apache License, Version 2.0
+    and redistributed with the same license.
+    [`License `__,
+    `Source `__]
+
+    Please refer to :func:`torchtext.prototype.models.T5Bundle` for the usage.
+    """
+
+T5_BASE_GENERATION = T5Bundle(
+    _path=urljoin(_TEXT_BUCKET, "t5.base.generation.pt"),
+    _config=T5Conf(encoder_only=False, linear_head=True),
+    transform=lambda: T5Transform(
+        urljoin(_TEXT_BUCKET, "t5_tokenizer_base.model"),
+        max_seq_len=512,
+        eos_idx=1,
+        padding_idx=0,
+    ),
+)
+
+T5_BASE_GENERATION.__doc__ = """
+    T5_BASE_GENERATION is an encoder-decoder model from a pre-trained T5 model with the base configuration.
+    It returns the output of the final layer of the decoder after passing it through a linear layer that projects
+    the hidden states to the model vocabulary. This output can then be used for language generation.
 
     The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
     <https://arxiv.org/pdf/1910.10683.pdf>`. It introduces a unified framework that converts text-based
diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py
index 41ffb12b4e..6a5349ce53 100644
--- a/torchtext/prototype/models/t5/model.py
+++ b/torchtext/prototype/models/t5/model.py
@@ -11,6 +11,7 @@
 @dataclass
 class T5Conf:
     encoder_only: bool = False
+    linear_head: bool = False
     embedding_dim: int = 768
     num_attention_heads: int = 12
     num_encoder_layers: int = 12
@@ -35,7 +36,8 @@ class T5Model(nn.Module):
     Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research.
     Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html
     Args:
-        config.encoder_only: Whether or not model should consist of only the encoder as opposed to encoder-decoder (required)
+        config.encoder_only: Whether or not the model should consist of only the encoder as opposed to encoder-decoder (default=False).
+        config.linear_head: Whether or not a linear layer should be used to project the output of the decoder's last layer to the vocab (default=False).
         config.embedding_dim: Number of expected features in the encoder/decoder inputs (default=768).
         config.num_attention_heads: Number of heads in the multiheadattention models (default=12).
         config.num_encoder_layers: Number of encoder layers in the encoder (default=12).
@@ -55,11 +57,12 @@ class T5Model(nn.Module):
         freeze: Indicates whether or not to freeze the model weights. (default: False)
     Examples:
         >>> from torchtext.prototype.models import T5Conf, T5Model
-        >>> t5_config = T5Conf(encoder_only=False)
+        >>> t5_config = T5Conf(encoder_only=False, linear_head=True)
         >>> t5_model = T5Model(t5_config)
-        >>> encoder_input = torch.rand((32, 10, 512))
-        >>> decoder_input = torch.rand((32, 20, 512))
-        >>> out = t5_model(encoder_input, decoder_input)
+        >>> encoder_input = torch.randint(0, t5_config.vocab_size, (32, 512))
+        >>> out = t5_model(encoder_input)['decoder_output']
+        >>> out.shape
+        torch.Size([32, 1, 32128])
     """
 
     def __init__(
@@ -73,7 +76,9 @@ def __init__(
 
         assert isinstance(config, T5Conf)
 
+        self.config = config
         self.encoder_only = config.encoder_only
+        self.linear_head = config.linear_head
         self.padding_idx = config.padding_idx
         self.training = config.training
         self.dropout = config.dropout if config.training else 0.0
@@ -118,6 +123,9 @@ def __init__(
             self.dropout3 = nn.Dropout(self.dropout)
             self.dropout4 = nn.Dropout(self.dropout)
 
+        if config.linear_head:
+            self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
+
         if freeze:
             for p in self.parameters():
                 p.requires_grad = False
@@ -191,6 +199,13 @@ def forward(
             decoder_output = self.dropout4(decoder_output)
             decoder_hidden_states = decoder_hidden_states + (decoder_output,)
 
+            if self.linear_head:
+                # Rescale output before projecting on vocab. This happens when the encoder and decoder share the
+                # same word embeddings, which is always the case in our t5 implementation.
+                # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
+                decoder_output = decoder_output * (self.config.embedding_dim ** -0.5)
+                decoder_output = self.lm_head(decoder_output)
+
             t5_output = {
                 "encoder_output": encoder_output,
                 "encoder_hidden_states": encoder_hidden_states,
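
Usage note, not part of the patch above: the `embedding_dim ** -0.5` rescale mirrors the linked Hugging Face T5 behavior for tied input/output embeddings. A minimal sketch of what `lm_head` computes under that convention; the tensors and sizes below are illustrative stand-ins, not values taken from the patch:

    >>> import torch
    >>> d_model, vocab_size = 768, 32128
    >>> lm_head = torch.nn.Linear(d_model, vocab_size, bias=False)  # mirrors self.lm_head
    >>> hidden = torch.randn(2, 1, d_model)  # stand-in for the decoder's final hidden states
    >>> scores = lm_head(hidden * d_model ** -0.5)  # rescale, then project onto the vocabulary
    >>> scores.shape
    torch.Size([2, 1, 32128])

Because the projection reuses `embedding_dim` and `vocab_size` from `T5Conf`, the same sketch holds for any configuration built with `linear_head=True`.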
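
A second sketch, tying the pieces together: everything through the `log_softmax` call is taken from the `T5_BASE_GENERATION` docstring example added in this patch; the final `argmax` is an illustrative greedy pick of the first generated token, not an API introduced by this change:

    >>> import torch, torchtext
    >>> import torch.nn.functional as F
    >>> t5_base_generation = torchtext.prototype.models.T5_BASE_GENERATION
    >>> transform = t5_base_generation.transform()
    >>> model = t5_base_generation.get_model()
    >>> model_input = transform(["Hello world", "Attention rocks!"])
    >>> output = model(model_input)["decoder_output"]  # (2, 1, 32128): one decoding step over the vocabulary
    >>> logits = F.log_softmax(output[:, -1], dim=-1)  # (2, 32128)
    >>> next_tokens = torch.argmax(logits, dim=-1)  # greedy choice per input sequence
    >>> next_tokens.shape
    torch.Size([2])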