From 925903c3cb1e733a3b1f68be03c4d60c97ba112c Mon Sep 17 00:00:00 2001
From: Joe Cummings
Date: Tue, 11 Oct 2022 06:54:47 -0700
Subject: [PATCH] Add padding_masks and tests for T5Model

Summary:
Added the following parameters to the `forward` method of the T5Model:
* `encoder_padding_mask`
* `decoder_padding_mask`

These allow users to specifically mask out the padding of input sequences. This matches the implementation of Transformers in PyTorch core.

Differential Revision: D40252794

fbshipit-source-id: cc0a81ff145db3459ac8a3793971c8de48f64cd7
---
 .../prototype/models/test_models.py    | 87 +++++++++++++++---
 torchtext/prototype/models/t5/model.py | 88 ++++++++++++++-----
 2 files changed, 142 insertions(+), 33 deletions(-)

diff --git a/test/torchtext_unittest/prototype/models/test_models.py b/test/torchtext_unittest/prototype/models/test_models.py
index 7d7fc9da66..33c34d7638 100644
--- a/test/torchtext_unittest/prototype/models/test_models.py
+++ b/test/torchtext_unittest/prototype/models/test_models.py
@@ -3,13 +3,12 @@
 
 import torch
 from torch.nn import functional as F
+from torchtext.prototype.models import T5Bundle, T5Conf, T5Model
 from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
 
 
 class TestModels(TorchtextTestCase):
     def test_t5_bundler_build_model(self) -> None:
-        from torchtext.prototype.models import T5Conf, T5Model, T5Bundle
-
         # case: user provides encoder checkpoint state dict
         dummy_encoder_conf = T5Conf(
             encoder_only=True,
@@ -21,7 +20,9 @@ def test_t5_bundler_build_model(self) -> None:
             num_decoder_layers=2,
         )
         dummy_t5_encoder = T5Model(dummy_encoder_conf)
-        t5_encoder_model = T5Bundle.build_model(config=dummy_encoder_conf, checkpoint=dummy_t5_encoder.state_dict())
+        t5_encoder_model = T5Bundle.build_model(
+            config=dummy_encoder_conf, checkpoint=dummy_t5_encoder.state_dict()
+        )
         self.assertEqual(t5_encoder_model.state_dict(), dummy_t5_encoder.state_dict())
 
         # case: user provides encoder-decoder checkpoint state dict
@@ -35,7 +36,9 @@ def test_t5_bundler_build_model(self) -> None:
             num_decoder_layers=2,
         )
         dummy_t5 = T5Model(dummy_t5_conf)
-        t5_model = T5Bundle.build_model(config=dummy_t5_conf, checkpoint=dummy_t5.state_dict())
+        t5_model = T5Bundle.build_model(
+            config=dummy_t5_conf, checkpoint=dummy_t5.state_dict()
+        )
         self.assertEqual(t5_model.state_dict(), dummy_t5.state_dict())
 
         # case: user provides checkpoint state dict for encoder-decoder with generation
@@ -53,12 +56,12 @@ def test_t5_bundler_build_model(self) -> None:
         t5_generation_model = T5Bundle.build_model(
             config=dummy_t5_generation_conf, checkpoint=dummy_t5_generation.state_dict()
         )
-        self.assertEqual(t5_generation_model.state_dict(), dummy_t5_generation.state_dict())
+        self.assertEqual(
+            t5_generation_model.state_dict(), dummy_t5_generation.state_dict()
+        )
 
     @patch("logging.Logger.warning")
     def test_t5_bundler_get_model(self, mock):
-        from torchtext.prototype.models import T5Conf, T5Bundle
-
         # encoder-decoder with generation
         dummy_t5_generation_conf = T5Conf(
             encoder_only=False,
@@ -77,8 +80,6 @@ def test_t5_bundler_get_model(self, mock):
         )
 
     def test_t5_bundler_raise_checkpoint(self) -> None:
-        from torchtext.prototype.models import T5Conf, T5Bundle
-
         # encoder-only
         with self.assertRaises(TypeError):
             dummy_encoder_conf = T5Conf(
@@ -132,8 +133,6 @@ def test_t5_bundler_raise_checkpoint(self) -> None:
         )
 
     def test_t5_bundler_conf_property(self) -> None:
-        from torchtext.prototype.models import T5Conf, T5Bundle
-
         dummy_t5_conf = T5Conf(
             encoder_only=False,
             vocab_size=10,
@@ -148,7 +147,6 @@ def test_t5_bundler_conf_property(self) -> None:
 
     def test_t5_bundler_train(self) -> None:
         from torch.optim import SGD
-        from torchtext.prototype.models import T5Conf, T5Model, T5Bundle
 
         def _train(model):
             optim = SGD(model.parameters(), lr=1)
@@ -181,3 +179,68 @@ def _train(model):
 
         _train(model)
         self.assertNotEqual(model.state_dict(), current_state_dict)
+
+    def test_t5_model_forward_with_encoder_mask_encoder_only(self) -> None:
+        dummy_conf = T5Conf(
+            encoder_only=True,
+            linear_head=True,
+            vocab_size=100,
+            embedding_dim=16,
+            ffn_dimension=64,
+            num_attention_heads=2,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            training=False,
+        )
+        dummy_model = T5Model(dummy_conf)
+        tokens = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0, 0, 0]])
+        mask = tokens.eq(0)
+
+        with torch.no_grad():
+            output_with_mask = dummy_model(
+                encoder_tokens=tokens, encoder_padding_mask=mask
+            )
+            output_no_mask = dummy_model(tokens)
+
+        torch.testing.assert_close(
+            output_with_mask["encoder_output"],
+            output_no_mask["encoder_output"],
+            atol=1e-04,
+            rtol=2.5e-06,
+        )
+
+    def test_t5_model_forward_with_encoder_mask_encoder_decoder(self) -> None:
+        dummy_conf = T5Conf(
+            encoder_only=False,
+            linear_head=True,
+            vocab_size=100,
+            embedding_dim=16,
+            ffn_dimension=64,
+            num_attention_heads=2,
+            num_encoder_layers=2,
+            num_decoder_layers=2,
+            training=False,
+        )
+        dummy_model = T5Model(dummy_conf)
+        enc_tokens = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0, 0, 0]])
+        dec_tokens = torch.tensor([[6, 7, 8, 9, 10, 11, 0, 0, 0, 0]])
+        enc_mask = enc_tokens.eq(0)
+        dec_mask = dec_tokens.eq(0)
+
+        with torch.no_grad():
+            output_with_mask = dummy_model(
+                encoder_tokens=enc_tokens,
+                encoder_padding_mask=enc_mask,
+                decoder_tokens=dec_tokens,
+                decoder_padding_mask=dec_mask,
+            )
+            output_no_mask = dummy_model(
+                encoder_tokens=enc_tokens, decoder_tokens=dec_tokens
+            )
+
+        torch.testing.assert_close(
+            output_with_mask["decoder_output"],
+            output_no_mask["decoder_output"],
+            atol=1e-04,
+            rtol=2.5e-06,
+        )
diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py
index 2812af3c74..61e1dabe01 100644
--- a/torchtext/prototype/models/t5/model.py
+++ b/torchtext/prototype/models/t5/model.py
@@ -1,11 +1,11 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union, Callable
+from typing import Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn as nn
 from torch import Tensor
 
-from .modules import T5Encoder, T5Decoder, T5LayerNorm
+from .modules import T5Decoder, T5Encoder, T5LayerNorm
 
 
 @dataclass
@@ -88,7 +88,9 @@ def __init__(
         self.device = device
         self.dtype = dtype
 
-        self.token_embeddings = nn.Embedding(config.vocab_size, config.embedding_dim, config.padding_idx)
+        self.token_embeddings = nn.Embedding(
+            config.vocab_size, config.embedding_dim, config.padding_idx
+        )
         self.encoder = T5Encoder(
             d_model=config.embedding_dim,
             nhead=config.num_attention_heads,
@@ -129,7 +131,9 @@ def __init__(
             self.decoder = None
 
         if config.linear_head:
-            self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
+            self.lm_head = nn.Linear(
+                config.embedding_dim, config.vocab_size, bias=False
+            )
         else:
             self.lm_head = None
 
@@ -140,23 +144,31 @@ def __init__(
     def forward(
         self,
         encoder_tokens: Tensor,
-        decoder_tokens: Optional[Tensor] = None,
         encoder_mask: Optional[Tensor] = None,
+        encoder_padding_mask: Optional[Tensor] = None,
+        decoder_tokens: Optional[Tensor] = None,
         decoder_mask: Optional[Tensor] = None,
-    ) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]:
+        decoder_padding_mask: Optional[Tensor] = None,
+    ) -> Dict[
+        str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]
+    ]:
         r"""Pass the inputs (and mask) through the decoder layer in turn.
 
         Args:
             encoder_tokens: Tokenized input sequence to the encoder.
                 Must be batch first with shape (B, Ne) where B is the batch size and Ne is the encoder input sequence length. (required).
+            encoder_mask: Additive mask for the encoder input sequence.
+                Must have shape (Ne, Ne) (optional).
+            encoder_padding_mask: Padding mask for encoder input sequence.
+                Must have shape (B, Ne) (optional).
             decoder_tokens: Tokenized input sequence to the decoder.
                 Must be batch first with shape (B, Nd) where B is the batch size and Nd is the decoder input sequence length.
                 If None and model is encoder-decoder, will initialize decoder input sequence to begin with padding index. (optional).
-            encoder_mask: Self-attention mask for the encoder input sequence.
-                Must have shape (Ne, Ne) (optional).
-            decoder_mask: Self-attention mask for the decoder input sequence.
+            decoder_mask: Additive mask for the decoder input sequence.
                 Must have shape (Nd, Nd) (optional).
+            decoder_padding_mask: Padding mask for decoder input sequence.
+                Must have shape (B, Nd) (optional).
 
         Returns:
             encoder_output: Output Tensor from the final layer of the encoder
             encoder_hidden_states: Tuple of output Tensors from each layer of the encoder
@@ -168,10 +180,24 @@ def forward(
             encoder_sa_scores: Tuple of self-attention scores computed at each layer of the decoder
             encoder_ca_scores: Tuple of cross-attention scores computed at each layer of the decoder
         """
-        encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
+        if encoder_padding_mask is None:
+            encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
+
+        batch_size = encoder_tokens.shape[0]
+        seq_len = encoder_tokens.shape[1]
+
+        assert encoder_padding_mask.shape == (batch_size, seq_len)
+
         encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens))
-        encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa = self.encoder(
-            encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask
+        (
+            encoder_output,
+            encoder_hidden_states,
+            encoder_position_bias,
+            encoder_sa,
+        ) = self.encoder(
+            encoder_embeddings,
+            tgt_mask=encoder_mask,
+            tgt_key_padding_mask=encoder_padding_mask,
         )
         encoder_output = self.norm1(encoder_output)
 
@@ -184,20 +210,34 @@ def forward(
 
         # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
         if decoder_tokens is None:
-            decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx
+            decoder_tokens = (
+                torch.ones((encoder_tokens.size(0), 1), dtype=torch.long)
+                * self.padding_idx
+            )
+
+        tgt_seq_len = decoder_tokens.shape[1]
 
         if decoder_mask is None:
             assert decoder_tokens is not None and decoder_tokens.dim() == 2
-            tgt_len = decoder_tokens.shape[1]
-            decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
+            decoder_mask = torch.triu(
+                torch.ones((tgt_seq_len, tgt_seq_len), dtype=torch.float64),
+                diagonal=1,
+            )
             decoder_mask = decoder_mask.to(torch.bool)
 
-        decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
-        # T5 implemention uses padding idx to start sequence. Want to ignore this when masking
-        decoder_padding_mask[:, 0] = False
+        if decoder_padding_mask is None:
+            decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
+            # T5 implementation uses padding idx to start sequence. Want to ignore this when masking
+            decoder_padding_mask[:, 0] = False
 
         decoder_embeddings = self.dropout3(self.token_embeddings(decoder_tokens))
-        decoder_output, decoder_hidden_states, decoder_position_bias, decoder_sa, decoder_ca = self.decoder(
+        (
+            decoder_output,
+            decoder_hidden_states,
+            decoder_position_bias,
+            decoder_sa,
+            decoder_ca,
+        ) = self.decoder(
             decoder_embeddings,
             memory=encoder_output,
             tgt_mask=decoder_mask,
@@ -215,7 +255,7 @@ def forward(
             # Rescale output before projecting on vocab. This happens when the encoder and decoder share the
             # same word embeddings, which is always the case in our t5 implementation.
             # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
-            decoder_output = decoder_output * (self.embedding_dim ** -0.5)
+            decoder_output = decoder_output * (self.embedding_dim**-0.5)
             decoder_output = self.lm_head(decoder_output)
 
         t5_output = {
@@ -238,7 +278,13 @@ def forward(
             }
 
         assert torch.jit.isinstance(
-            t5_output, Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]
+            t5_output,
+            Dict[
+                str,
+                Union[
+                    Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]
+                ],
+            ],
        )
 
         return t5_output
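
Example usage of the new padding-mask arguments. This is an illustrative sketch modeled on the new tests above, not part of the patch; the toy T5Conf sizes and token values are arbitrary, and the padding index of 0 mirrors what the tests assume.

    import torch
    from torchtext.prototype.models import T5Conf, T5Model

    conf = T5Conf(
        encoder_only=False,
        linear_head=True,
        vocab_size=100,
        embedding_dim=16,
        ffn_dimension=64,
        num_attention_heads=2,
        num_encoder_layers=2,
        num_decoder_layers=2,
        training=False,
    )
    model = T5Model(conf)

    # Token id 0 is treated as padding here, matching the tests above.
    enc_tokens = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0, 0, 0]])
    dec_tokens = torch.tensor([[6, 7, 8, 9, 10, 11, 0, 0, 0, 0]])

    with torch.no_grad():
        out = model(
            encoder_tokens=enc_tokens,
            encoder_padding_mask=enc_tokens.eq(0),  # bool mask, shape (B, Ne); True marks padding
            decoder_tokens=dec_tokens,
            decoder_padding_mask=dec_tokens.eq(0),  # bool mask, shape (B, Nd); True marks padding
        )
    print(out["decoder_output"].shape)

If the padding-mask arguments are omitted, forward falls back to deriving the masks from the padding index, as before this change.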