From 332185485e803163145d554106f7cba8bae4c34e Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 8 Aug 2022 10:32:38 -0400 Subject: [PATCH 1/6] type annotate device --- torchtext/prototype/models/t5/model.py | 2 +- torchtext/prototype/models/t5/modules.py | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py index 6a5349ce53..93f4d835e2 100644 --- a/torchtext/prototype/models/t5/model.py +++ b/torchtext/prototype/models/t5/model.py @@ -69,7 +69,7 @@ def __init__( self, config: T5Conf, freeze: bool = False, - device=None, + device: Optional[torch.device] = None, dtype=None, ) -> None: super().__init__() diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py index 77b6733de4..0c1eb48726 100644 --- a/torchtext/prototype/models/t5/modules.py +++ b/torchtext/prototype/models/t5/modules.py @@ -26,14 +26,14 @@ class T5MultiheadAttention(nn.MultiheadAttention): def __init__( self, - embed_dim, - num_heads, - is_decoder=False, - dropout=0.0, - bias=False, - kdim=None, - vdim=None, - device=None, + embed_dim: int, + num_heads: int, + is_decoder: bool = False, + dropout: float = 0.0, + bias: bool = False, + kdim: int = None, + vdim: int = None, + device: Optional[torch.device] = None, dtype=None, ) -> None: r""" @@ -354,7 +354,7 @@ def _compute_bias( relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, bidirectional: bool = True, - device=None, + device: Optional[torch.device] = None, ) -> Tensor: """Compute binned relative position bias""" if device is None: @@ -498,7 +498,7 @@ def __init__( relative_attention_max_distance: int = 128, compute_relative_attention_bias: bool = False, relative_attention_bias: Optional[Tensor] = None, - device=None, + device: Optional[torch.device] = None, dtype=None, ) -> None: super().__init__() @@ -659,7 +659,7 @@ def __init__( layer_norm_eps: float = 
1e-6, relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, - device=None, + device: Optional[torch.device] = None, dtype=None, ) -> None: super().__init__() From 59d0e9e44f127e60d4f1b1e2675ce6151c429360 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 8 Aug 2022 10:49:48 -0400 Subject: [PATCH 2/6] refactor relative_attention_bias --- torchtext/prototype/models/t5/modules.py | 58 ++++++++++-------------- 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py index 0c1eb48726..e536502e1c 100644 --- a/torchtext/prototype/models/t5/modules.py +++ b/torchtext/prototype/models/t5/modules.py @@ -33,6 +33,9 @@ def __init__( bias: bool = False, kdim: int = None, vdim: int = None, + compute_relative_attention_bias=False, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, device: Optional[torch.device] = None, dtype=None, ) -> None: @@ -54,6 +57,13 @@ def __init__( self.v_proj_weight = nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs)) self.register_parameter("in_proj_weight", None) + self.compute_relative_attention_bias = compute_relative_attention_bias + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + + if compute_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(relative_attention_num_buckets, num_heads) + def forward( self, query: Tensor, @@ -63,10 +73,6 @@ def forward( need_weights: bool = True, attn_mask: Optional[Tensor] = None, average_attn_weights: bool = False, - compute_relative_attention_bias=False, - relative_attention_num_buckets=32, - relative_attention_max_distance=128, - relative_attention_bias: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: r""" @@ -124,10 +130,6 @@ def forward( query, key, value, - 
compute_relative_attention_bias=compute_relative_attention_bias, - relative_attention_num_buckets=relative_attention_num_buckets, - relative_attention_max_distance=relative_attention_max_distance, - relative_attention_bias=relative_attention_bias, position_bias=position_bias, key_padding_mask=key_padding_mask, need_weights=need_weights, @@ -142,10 +144,6 @@ def _t5_multi_head_attention_forward( query: Tensor, key: Tensor, value: Tensor, - compute_relative_attention_bias: bool, - relative_attention_num_buckets: Optional[int], - relative_attention_max_distance: Optional[int], - relative_attention_bias: Optional[Tensor], position_bias: Optional[Tensor], key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, @@ -255,7 +253,7 @@ def _t5_multi_head_attention_forward( # NOTE: Modification to torch.nn.functional._multi_head_attention_forward to incorporate relative attention bias if position_bias is None: - if not compute_relative_attention_bias: + if not self.compute_relative_attention_bias: position_bias = torch.zeros( (self.num_heads, tgt_len, src_len), device=k.device, dtype=k.dtype ).unsqueeze(0) @@ -263,9 +261,6 @@ def _t5_multi_head_attention_forward( position_bias = self._compute_bias( tgt_len, src_len, - relative_attention_bias, - relative_attention_num_buckets=relative_attention_num_buckets, - relative_attention_max_distance=relative_attention_max_distance, bidirectional=(not self.is_decoder), device=k.device, ) @@ -350,25 +345,22 @@ def _compute_bias( self, query_length: int, key_length: int, - relative_attention_bias: Tensor, - relative_attention_num_buckets: int = 32, - relative_attention_max_distance: int = 128, bidirectional: bool = True, device: Optional[torch.device] = None, ) -> Tensor: """Compute binned relative position bias""" if device is None: - device = relative_attention_bias.weight.device + device = self.relative_attention_bias.weight.device context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] 
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( relative_position, # shape (query_length, key_length) bidirectional=bidirectional, - num_buckets=relative_attention_num_buckets, - max_distance=relative_attention_max_distance, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, ) - values = relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values @@ -395,7 +387,7 @@ def _relative_position_bucket( """ relative_buckets = 0 if bidirectional: - num_buckets //= 2 + num_buckets = num_buckets // 2 relative_buckets += (relative_position > 0).to(torch.long) * num_buckets relative_position = torch.abs(relative_position) else: @@ -476,7 +468,6 @@ class T5Layer(nn.Module): compute_relative_attention_bias: Whether or not the relative position embeddings need to be computed. Typically occurs in the first layer of encoder/decoder and resulting position embeddings are returned to be passed up to higher layers. (default: False) - relative_attention_bias: nn.Embeding object used to compute relative position embeddings. 
(default: None) Examples:: >>> decoder_layer = T5Layer(is_decoder=True, d_model=768, nhead=12) @@ -497,7 +488,6 @@ def __init__( relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, compute_relative_attention_bias: bool = False, - relative_attention_bias: Optional[Tensor] = None, device: Optional[torch.device] = None, dtype=None, ) -> None: @@ -507,10 +497,17 @@ def __init__( self.compute_relative_attention_bias = compute_relative_attention_bias self.relative_attention_num_buckets = relative_attention_num_buckets self.relative_attention_max_distance = relative_attention_max_distance - self.relative_attention_bias = relative_attention_bias self.self_attn = T5MultiheadAttention( - d_model, nhead, is_decoder=is_decoder, dropout=dropout, device=device, dtype=dtype + d_model, + nhead, + is_decoder=is_decoder, + dropout=dropout, + compute_relative_attention_bias=compute_relative_attention_bias, + relative_attention_num_buckets=relative_attention_num_buckets, + relative_attention_max_distance=relative_attention_max_distance, + device=device, + dtype=dtype, ) self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False) self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False) @@ -595,10 +592,6 @@ def _sa_block( attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=True, - compute_relative_attention_bias=self.compute_relative_attention_bias, - relative_attention_num_buckets=self.relative_attention_num_buckets, - relative_attention_max_distance=self.relative_attention_max_distance, - relative_attention_bias=self.relative_attention_bias, position_bias=position_bias, ) @@ -677,7 +670,6 @@ def __init__( relative_attention_num_buckets, relative_attention_max_distance, compute_relative_attention_bias=True if i == 0 else False, - relative_attention_bias=nn.Embedding(relative_attention_num_buckets, nhead) if i == 0 else None, device=device, dtype=dtype, ) From 22809fa5afaa508e06280d9472b957b43591ea02 Mon Sep 17 00:00:00 2001 
From: pmabbo13 Date: Mon, 8 Aug 2022 14:36:17 -0400 Subject: [PATCH 3/6] breaking out encoder and decoder layer and stacks --- torchtext/prototype/models/t5/model.py | 31 ++- torchtext/prototype/models/t5/modules.py | 278 ++++++++++++++++++++--- 2 files changed, 260 insertions(+), 49 deletions(-) diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py index 93f4d835e2..63bfcc8109 100644 --- a/torchtext/prototype/models/t5/model.py +++ b/torchtext/prototype/models/t5/model.py @@ -1,11 +1,11 @@ from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union, Callable +from typing import Dict, List, Optional, Union, Callable import torch import torch.nn as nn from torch import Tensor -from .modules import T5Stack, T5LayerNorm +from .modules import T5Encoder, T5Decoder, T5LayerNorm @dataclass @@ -77,6 +77,7 @@ def __init__( assert isinstance(config, T5Conf) self.config = config + self.embedding_dim = config.embedding_dim self.encoder_only = config.encoder_only self.linear_head = config.linear_head self.padding_idx = config.padding_idx @@ -86,8 +87,7 @@ def __init__( self.dtype = dtype self.token_embeddings = nn.Embedding(config.vocab_size, config.embedding_dim, config.padding_idx) - self.encoder = T5Stack( - is_decoder=False, + self.encoder = T5Encoder( d_model=config.embedding_dim, nhead=config.num_attention_heads, num_layers=config.num_encoder_layers, @@ -105,8 +105,7 @@ def __init__( self.dropout2 = nn.Dropout(self.dropout) if not config.encoder_only: - self.decoder = T5Stack( - is_decoder=True, + self.decoder = T5Decoder( d_model=config.embedding_dim, nhead=config.num_attention_heads, num_layers=config.num_decoder_layers, @@ -122,9 +121,13 @@ def __init__( self.norm2 = T5LayerNorm(config.embedding_dim) self.dropout3 = nn.Dropout(self.dropout) self.dropout4 = nn.Dropout(self.dropout) + else: + self.decoder = None if config.linear_head: self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, 
bias=False) + else: + self.lm_head = None if freeze: for p in self.parameters(): @@ -136,7 +139,7 @@ def forward( decoder_tokens: Tensor = None, encoder_mask: Optional[Tensor] = None, decoder_mask: Optional[Tensor] = None, - ) -> Dict[str, Union[Tensor, Tuple[Tensor]]]: + ) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]: r"""Pass the inputs (and mask) through the decoder layer in turn. Args: encoder_tokens: Tokenized input sequence to the encoder. @@ -163,23 +166,26 @@ def forward( """ encoder_padding_mask = encoder_tokens.eq(self.padding_idx) encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens)) - encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa, _ = self.encoder( + encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa = self.encoder( encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask ) encoder_output = self.norm1(encoder_output) encoder_output = self.dropout2(encoder_output) - encoder_hidden_states = encoder_hidden_states + (encoder_output,) + encoder_hidden_states.append(encoder_output) if not self.encoder_only: + assert self.decoder is not None + # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx. if decoder_tokens is None: decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx if decoder_mask is None: tgt_len = decoder_tokens.shape[1] - decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool() + decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1) + decoder_mask = decoder_mask.to(torch.bool) decoder_padding_mask = decoder_tokens.eq(self.padding_idx) # T5 implemention uses padding idx to start sequence. 
Want to ignore this when masking @@ -197,13 +203,14 @@ def forward( decoder_output = self.norm2(decoder_output) decoder_output = self.dropout4(decoder_output) - decoder_hidden_states = decoder_hidden_states + (decoder_output,) + decoder_hidden_states.append(decoder_output) if self.linear_head: + assert self.lm_head is not None # Rescale output before projecting on vocab. This happens when the encoder and decoder share the # same word embeddings, which is always the case in our t5 implementation. # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661 - decoder_output = decoder_output * (self.config.embedding_dim ** -0.5) + decoder_output = decoder_output * (self.embedding_dim ** -0.5) decoder_output = self.lm_head(decoder_output) t5_output = { diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py index e536502e1c..86dde3bac5 100644 --- a/torchtext/prototype/models/t5/modules.py +++ b/torchtext/prototype/models/t5/modules.py @@ -15,7 +15,7 @@ import math import warnings -from typing import Optional, Tuple, Union, Callable +from typing import List, Optional, Tuple, Union, Callable import torch import torch.nn as nn @@ -63,6 +63,8 @@ def __init__( if compute_relative_attention_bias: self.relative_attention_bias = nn.Embedding(relative_attention_num_buckets, num_heads) + else: + self.relative_attention_bias = None def forward( self, @@ -349,6 +351,7 @@ def _compute_bias( device: Optional[torch.device] = None, ) -> Tensor: """Compute binned relative position bias""" + assert self.relative_attention_bias is not None if device is None: device = self.relative_attention_bias.weight.device context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] @@ -385,7 +388,7 @@ def _relative_position_bucket( Returns: a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) """ - 
relative_buckets = 0 + relative_buckets = torch.zeros(relative_position.shape, dtype=torch.long) if bidirectional: num_buckets = num_buckets // 2 relative_buckets += (relative_position > 0).to(torch.long) * num_buckets @@ -445,7 +448,7 @@ def forward(self, hidden_states: Tensor) -> Tensor: # NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L622 -class T5Layer(nn.Module): +class T5EncoderLayer(nn.Module): r"""T5Layer is made up of self-attn, cross-attn (decoder only) and feedforward network. This T5 layer is based on the paper: "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer". @@ -478,7 +481,6 @@ class T5Layer(nn.Module): def __init__( self, - is_decoder: bool, d_model: int, nhead: int, dim_feedforward: int = 3072, @@ -493,7 +495,6 @@ def __init__( ) -> None: super().__init__() - self.is_decoder = is_decoder self.compute_relative_attention_bias = compute_relative_attention_bias self.relative_attention_num_buckets = relative_attention_num_buckets self.relative_attention_max_distance = relative_attention_max_distance @@ -501,7 +502,7 @@ def __init__( self.self_attn = T5MultiheadAttention( d_model, nhead, - is_decoder=is_decoder, + is_decoder=False, dropout=dropout, compute_relative_attention_bias=compute_relative_attention_bias, relative_attention_num_buckets=relative_attention_num_buckets, @@ -517,13 +518,6 @@ def __init__( self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) - if is_decoder: - self.cross_attn = T5MultiheadAttention( - d_model, nhead, is_decoder=is_decoder, dropout=dropout, device=device, dtype=dtype - ) - self.norm3 = T5LayerNorm(d_model, eps=layer_norm_eps) - self.dropout4 = nn.Dropout(dropout) - if isinstance(activation, str): assert activation in ( "relu", @@ -539,13 +533,10 @@ def __init__( def forward( self, tgt: Tensor, - memory: Tensor, tgt_mask: 
Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]: + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: r"""Pass the inputs (and mask) through the encoder/decoder layer. Args: tgt: Input sequence to the encoder/decoder layer. (required). @@ -570,12 +561,9 @@ def forward( x = tgt sa_out, position_bias, sa_scores = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, position_bias) x = x + sa_out - if self.is_decoder: - ca_out, ca_scores = self._ca_block(self.norm3(x), memory, memory_mask, memory_key_padding_mask) - x = x + ca_out x = x + self._ff_block(self.norm2(x)) - return x, position_bias, sa_scores, ca_scores if self.is_decoder else None + return x, position_bias, sa_scores # Self-attention block def _sa_block( @@ -584,7 +572,7 @@ def _sa_block( attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], position_bias: Optional[Tensor], - ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: attn = self.self_attn( x, x, @@ -602,6 +590,130 @@ def _sa_block( return self.dropout1(x), position_bias, scores + # Feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout2(self.activation(self.linear1(x)))) + return self.dropout3(x) + + +# NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L622 +class T5DecoderLayer(T5EncoderLayer): + r"""T5Layer is made up of self-attn, cross-attn (decoder only) and feedforward network. + This T5 layer is based on the paper: + "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer". 
+ Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, + Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research. + Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html + Users may modify or implement in a different way during application. + Args: + is_decoder: Whether or not the layer belongs to the decoder. (required) + d_model: Number of expected features in the input (required). + nhead: Number of heads in the multihead attention models (required). + dim_feedforward: Dimension of the feedforward network model (default=3072). + dropout: Dropout value (default=0.1). + activation: Activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. (default: relu) + layer_norm_eps: The eps value in layer normalization components (default=1e-6). + relative_attention_num_buckets: Number of relative position buckets (default: 32) + relative_attention_max_distance: Maximum threshold on the relative distance used to + allocate buckets. Anything larger gets placed in the same bucket (default: 128) + compute_relative_attention_bias: Whether or not the relative position embeddings + need to be computed. Typically occurs in the first layer of encoder/decoder + and resulting position embeddings are returned to be passed up to higher layers. 
(default: False) + + Examples:: + >>> decoder_layer = T5Layer(is_decoder=True, d_model=768, nhead=12) + >>> memory = torch.rand(32, 10, 768) + >>> tgt = torch.rand(32, 20, 768) + >>> out = deoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 3072, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-6, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + compute_relative_attention_bias: bool = False, + device: Optional[torch.device] = None, + dtype=None, + ) -> None: + super().__init__( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + relative_attention_num_buckets, + relative_attention_max_distance, + compute_relative_attention_bias, + device, + dtype, + ) + + self.cross_attn = T5MultiheadAttention( + d_model, nhead, is_decoder=True, dropout=dropout, device=device, dtype=dtype + ) + self.norm3 = T5LayerNorm(d_model, eps=layer_norm_eps) + self.dropout4 = nn.Dropout(dropout) + + if isinstance(activation, str): + assert activation in ( + "relu", + "gelu", + ), f"Do not support '{activation}' activation. Use either 'relu' or 'gelu'" + if activation == "relu": + self.activation = F.relu + elif activation == "gelu": + self.activation = F.gelu + else: + self.activation = activation + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + r"""Pass the inputs (and mask) through the encoder/decoder layer. + Args: + tgt: Input sequence to the encoder/decoder layer. (required). 
+ Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence + length, and E is the model dimension. + memory: Sequence from the last layer of the encoder (used for decoder only). (required). + Must have shape (B, Nts, E) where B is the batch size, Ns is the source sequence + length, and E is the model dimension. + tgt_mask: Attention mask for self-attention. (optional). + Must have shape (Nt, Nt). + memory_mask: Attention mask for cross-attention (decoder-only) (optional). + Must have shape (Nt, Ns). + tgt_key_padding_mask: Mask for the tgt keys per batch (optional). + Must have shape (B, Nt). + memory_key_padding_mask: Mask for the memory keys per batch (decoder-only) (optional). + Must have shape (B, Ns). + position_bias: Relative attention bias to be used when computing self-attention scores (optional) + Must have shape (B, H, Nt, Nt) where H is the number of heads. + """ + + # See Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf + x = tgt + sa_out, position_bias, sa_scores = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, position_bias) + x = x + sa_out + ca_out, ca_scores = self._ca_block(self.norm3(x), memory, memory_mask, memory_key_padding_mask) + x = x + ca_out + x = x + self._ff_block(self.norm2(x)) + + return x, position_bias, sa_scores, ca_scores + # Cross attention block def _ca_block( self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor] @@ -611,14 +723,108 @@ def _ca_block( scores = attn[2] return self.dropout4(x), scores - # Feed forward block - def _ff_block(self, x: Tensor) -> Tensor: - x = self.linear2(self.dropout2(self.activation(self.linear1(x)))) - return self.dropout3(x) + +# NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L835 +class T5Encoder(nn.Module): + r"""T5 is a stack of N encoder/decoder layers + Args: + is_decoder: 
Whether or not the layer belongs to the decoder. (required) + d_model: Number of expected features in the input (required). + nhead: Number of heads in the multihead attention models (required). + num_layers: Number of encoder/decoder layers in the stack (required) + dim_feedforward: Dimension of the feedforward network model (default=3072). + dropout: Dropout value (default=0.1). + activation: Activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. (default: relu) + layer_norm_eps: The eps value in layer normalization components (default=1e-6). + relative_attention_num_buckets: Number of relative position buckets (default: 32) + relative_attention_max_distance: Maximum threshold on the relative distance used to + allocate buckets. Anything larger gets placed in the same bucket (defulat: 128) + Examples:: + >>> decoder = nn.T5Stack(is_decoder=True, d_model=768, nhead=12, num_layers=12) + >>> memory = torch.rand(32, 10, 512) + >>> tgt = torch.rand(32, 10, 512) + >>> out = decoder(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + num_layers: int, + dim_feedforward: int = 3072, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-6, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, + device: Optional[torch.device] = None, + dtype=None, + ) -> None: + super().__init__() + + self.layers = nn.ModuleList( + [ + T5EncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + activation, + layer_norm_eps, + relative_attention_num_buckets, + relative_attention_max_distance, + compute_relative_attention_bias=True if i == 0 else False, + device=device, + dtype=dtype, + ) + for i in range(num_layers) + ] + ) + self.num_layers = num_layers + + def forward( + self, + tgt: Tensor, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, List[Tensor], 
Optional[Tensor], List[Optional[Tensor]]]: + r"""Pass the inputs (and mask) through the stack of encoder/decoder layers. + Args: + tgt: Input sequence to the encoder/decoder layer. (required). + Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence + length, and E is the model dimension. + memory: Sequence from the last layer of the encoder (used for decoder only). (required). + Must have shape (B, Nts, E) where B is the batch size, Ns is the source sequence + length, and E is the model dimension. + tgt_mask: Attention mask for self-attention. (optional). + Must have shape (Nt, Nt). + memory_mask: Attention mask for cross-attention (decoder-only) (optional). + Must have shape (Nt, Ns). + tgt_key_padding_mask: Mask for the tgt keys per batch (optional). + Must have shape (B, Nt). + memory_key_padding_mask: Mask for the memory keys per batch (decoder-only) (optional). + Must have shape (B, Ns). + """ + output = tgt + position_bias = None + all_outputs = torch.jit.annotate(List[Tensor], []) + all_sa_scores = torch.jit.annotate(List[Optional[Tensor]], []) + for mod in self.layers: + all_outputs.append(output) + output, position_bias, sa_score = mod( + output, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + position_bias=position_bias, + ) + all_sa_scores.append(sa_score) + + return output, all_outputs, position_bias, all_sa_scores # NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L835 -class T5Stack(nn.Module): +class T5Decoder(nn.Module): r"""T5 is a stack of N encoder/decoder layers Args: is_decoder: Whether or not the layer belongs to the decoder. 
(required) @@ -642,7 +848,6 @@ class T5Stack(nn.Module): def __init__( self, - is_decoder: bool, d_model: int, nhead: int, num_layers: int, @@ -659,8 +864,7 @@ def __init__( self.layers = nn.ModuleList( [ - T5Layer( - is_decoder, + T5DecoderLayer( d_model, nhead, dim_feedforward, @@ -686,7 +890,7 @@ def forward( memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tuple[Tensor], Tensor, Tuple[Tensor], Tuple[Tensor]]: + ) -> Tuple[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]], List[Optional[Tensor]]]: r"""Pass the inputs (and mask) through the stack of encoder/decoder layers. Args: tgt: Input sequence to the encoder/decoder layer. (required). @@ -706,11 +910,11 @@ def forward( """ output = tgt position_bias = None - all_outputs = () - all_sa_scores = () - all_ca_scores = () + all_outputs = torch.jit.annotate(List[Tensor], []) + all_sa_scores = torch.jit.annotate(List[Optional[Tensor]], []) + all_ca_scores = torch.jit.annotate(List[Optional[Tensor]], []) for mod in self.layers: - all_outputs = all_outputs + (output,) + all_outputs.append(output) output, position_bias, sa_score, ca_score = mod( output, memory, @@ -720,7 +924,7 @@ def forward( memory_key_padding_mask=memory_key_padding_mask, position_bias=position_bias, ) - all_sa_scores = all_sa_scores + (sa_score,) - all_ca_scores = all_ca_scores + (ca_score,) + all_sa_scores.append(sa_score) + all_ca_scores.append(ca_score) return output, all_outputs, position_bias, all_sa_scores, all_ca_scores From 2b4f61c8b98d1e1decf10e5ff775436f6686ac6a Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 8 Aug 2022 15:36:02 -0400 Subject: [PATCH 4/6] updating doc strings --- torchtext/prototype/models/t5/modules.py | 87 +++++++++--------------- 1 file changed, 33 insertions(+), 54 deletions(-) diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py index 
86dde3bac5..e9b749582b 100644 --- a/torchtext/prototype/models/t5/modules.py +++ b/torchtext/prototype/models/t5/modules.py @@ -48,6 +48,12 @@ def __init__( bias: If specified, adds bias to input / output projection layers. Default: `False`. kdim: Total number of features for keys. Default: `None` (uses `kdim=embed_dim`). vdim: Total number of features for values. Default: `None` (uses `vdim=embed_dim`). + compute_relative_attention_bias: Whether or not the relative position embeddings + need to be computed. Typically occurs in the first layer of the encoder/decoder + and the resulting position embeddings are returned to be passed up to higher layers. (default: False) + relative_attention_num_buckets: Number of relative position buckets. Default: `32` + relative_attention_max_distance: Maximum threshold on the relative distance used to + allocate buckets. Anything larger gets placed in the same bucket. Default: `128` """ super().__init__(embed_dim, num_heads, dropout, bias, False, False, kdim, vdim, True, device, dtype) factory_kwargs = {"device": device, "dtype": dtype} @@ -109,13 +115,6 @@ def forward( average_attn_weights: If true, indicates that the returned `attn_weights` should be averaged across heads. Otherwise, `attn_weights` are provided separately per head. Note that this flag only has an effect when `need_weights=True`. Default: `False` (i.e. average weights across heads) - compute_relative_attention_bias: Whether or not the relative position embeddings - need to be computed. Wypically occurs in the first layer of the encoder/decoder - and the resulting position embeddings are returned to be passed up to higher layers. (defualt: False) - relative_attention_num_buckets: Number of relative position buckets. Default: `32` - relative_attention_max_distance: Maximum threshold on the relative distance used to - allocate buckets. Anything larger gets placed in the same bucket. 
Default: `128` - relative_attention_bias: nn.Embeding object used to compute relative position embeddings. Default: `None` position_bias: Position bias tensor used if to add relative attention bias to attention scores. Default: `None` Outputs: - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where :math:`N` is the batch size, @@ -449,7 +448,7 @@ def forward(self, hidden_states: Tensor) -> Tensor: # NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L622 class T5EncoderLayer(nn.Module): - r"""T5Layer is made up of self-attn, cross-attn (decoder only) and feedforward network. + r"""T5EncoderLayer is made up of a self-attn block and feedforward network. This T5 layer is based on the paper: "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer". Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, @@ -457,7 +456,6 @@ class T5EncoderLayer(nn.Module): Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html Users may modify or implement in a different way during application. Args: - is_decoder: Whether or not the layer belongs to the decoder. (required) d_model: Number of expected features in the input (required). nhead: Number of heads in the multihead attention models (required). dim_feedforward: Dimension of the feedforward network model (default=3072). @@ -469,14 +467,13 @@ class T5EncoderLayer(nn.Module): relative_attention_max_distance: Maximum threshold on the relative distance used to allocate buckets. Anything larger gets placed in the same bucket (default: 128) compute_relative_attention_bias: Whether or not the relative position embeddings - need to be computed. Typically occurs in the first layer of encoder/decoder + need to be computed. 
Typically occurs in the first layer of the encoder and resulting position embeddings are returned to be passed up to higher layers. (default: False) Examples:: - >>> decoder_layer = T5Layer(is_decoder=True, d_model=768, nhead=12) - >>> memory = torch.rand(32, 10, 768) + >>> encoder_layer = T5EncoderLayer(d_model=768, nhead=12) >>> tgt = torch.rand(32, 20, 768) - >>> out = deoder_layer(tgt, memory) + >>> out = encoder_layer(tgt) """ def __init__( @@ -537,22 +534,15 @@ def forward( tgt_key_padding_mask: Optional[Tensor] = None, position_bias: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: - r"""Pass the inputs (and mask) through the encoder/decoder layer. + r"""Pass the inputs (and mask) through the encoder layer. Args: - tgt: Input sequence to the encoder/decoder layer. (required). + tgt: Input sequence to the encoder layer. (required). Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence length, and E is the model dimension. - memory: Sequence from the last layer of the encoder (used for decoder only). (required). - Must have shape (B, Nts, E) where B is the batch size, Ns is the source sequence - length, and E is the model dimension. tgt_mask: Attention mask for self-attention. (optional). Must have shape (Nt, Nt). - memory_mask: Attention mask for cross-attention (decoder-only) (optional). - Must have shape (Nt, Ns). tgt_key_padding_mask: Mask for the tgt keys per batch (optional). Must have shape (B, Nt). - memory_key_padding_mask: Mask for the memory keys per batch (decoder-only) (optional). - Must have shape (B, Ns). position_bias: Relative attention bias to be used when computing self-attention scores (optional) Must have shape (B, H, Nt, Nt) where H is the number of heads. 
""" @@ -598,7 +588,7 @@ def _ff_block(self, x: Tensor) -> Tensor: # NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L622 class T5DecoderLayer(T5EncoderLayer): - r"""T5Layer is made up of self-attn, cross-attn (decoder only) and feedforward network. + r"""T5DecoderLayer is made up of a self-attn block, cross-attn block, and feedforward network. This T5 layer is based on the paper: "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer". Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, @@ -606,7 +596,6 @@ class T5DecoderLayer(T5EncoderLayer): Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html Users may modify or implement in a different way during application. Args: - is_decoder: Whether or not the layer belongs to the decoder. (required) d_model: Number of expected features in the input (required). nhead: Number of heads in the multihead attention models (required). dim_feedforward: Dimension of the feedforward network model (default=3072). @@ -618,14 +607,14 @@ class T5DecoderLayer(T5EncoderLayer): relative_attention_max_distance: Maximum threshold on the relative distance used to allocate buckets. Anything larger gets placed in the same bucket (default: 128) compute_relative_attention_bias: Whether or not the relative position embeddings - need to be computed. Typically occurs in the first layer of encoder/decoder + need to be computed. Typically occurs in the first layer of the decoder and resulting position embeddings are returned to be passed up to higher layers. 
(default: False) Examples:: - >>> decoder_layer = T5Layer(is_decoder=True, d_model=768, nhead=12) + >>> decoder_layer = T5DecoderLayer(d_model=768, nhead=12) >>> memory = torch.rand(32, 10, 768) >>> tgt = torch.rand(32, 20, 768) - >>> out = deoder_layer(tgt, memory) + >>> out = decoder_layer(tgt, memory) """ def __init__( @@ -686,19 +675,19 @@ def forward( ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]]: r"""Pass the inputs (and mask) through the encoder/decoder layer. Args: - tgt: Input sequence to the encoder/decoder layer. (required). + tgt: Input sequence to the decoder layer. (required). Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence length, and E is the model dimension. - memory: Sequence from the last layer of the encoder (used for decoder only). (required). + memory: Sequence from the last layer of the encoder. (required). Must have shape (B, Nts, E) where B is the batch size, Ns is the source sequence length, and E is the model dimension. tgt_mask: Attention mask for self-attention. (optional). Must have shape (Nt, Nt). - memory_mask: Attention mask for cross-attention (decoder-only) (optional). + memory_mask: Attention mask for cross-attention (optional). Must have shape (Nt, Ns). tgt_key_padding_mask: Mask for the tgt keys per batch (optional). Must have shape (B, Nt). - memory_key_padding_mask: Mask for the memory keys per batch (decoder-only) (optional). + memory_key_padding_mask: Mask for the memory keys per batch (optional). Must have shape (B, Ns). position_bias: Relative attention bias to be used when computing self-attention scores (optional) Must have shape (B, H, Nt, Nt) where H is the number of heads. 
@@ -726,12 +715,11 @@ def _ca_block( # NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L835 class T5Encoder(nn.Module): - r"""T5 is a stack of N encoder/decoder layers + r"""T5Encoder is a stack of N encoder layers Args: - is_decoder: Whether or not the layer belongs to the decoder. (required) d_model: Number of expected features in the input (required). nhead: Number of heads in the multihead attention models (required). - num_layers: Number of encoder/decoder layers in the stack (required) + num_layers: Number of encoder layers in the stack (required) dim_feedforward: Dimension of the feedforward network model (default=3072). dropout: Dropout value (default=0.1). activation: Activation function of the intermediate layer, can be a string @@ -741,10 +729,9 @@ class T5Encoder(nn.Module): relative_attention_max_distance: Maximum threshold on the relative distance used to allocate buckets. Anything larger gets placed in the same bucket (defulat: 128) Examples:: - >>> decoder = nn.T5Stack(is_decoder=True, d_model=768, nhead=12, num_layers=12) - >>> memory = torch.rand(32, 10, 512) + >>> encoder = T5Encoder(d_model=768, nhead=12, num_layers=12) >>> tgt = torch.rand(32, 10, 512) - >>> out = decoder(tgt, memory) + >>> out = encoder(tgt) """ def __init__( @@ -789,22 +776,15 @@ def forward( tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]: - r"""Pass the inputs (and mask) through the stack of encoder/decoder layers. + r"""Pass the inputs (and mask) through the stack of encoder layers. Args: - tgt: Input sequence to the encoder/decoder layer. (required). + tgt: Input sequence to the encoder layer. (required). Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence length, and E is the model dimension. 
- memory: Sequence from the last layer of the encoder (used for decoder only). (required). - Must have shape (B, Nts, E) where B is the batch size, Ns is the source sequence - length, and E is the model dimension. tgt_mask: Attention mask for self-attention. (optional). Must have shape (Nt, Nt). - memory_mask: Attention mask for cross-attention (decoder-only) (optional). - Must have shape (Nt, Ns). tgt_key_padding_mask: Mask for the tgt keys per batch (optional). Must have shape (B, Nt). - memory_key_padding_mask: Mask for the memory keys per batch (decoder-only) (optional). - Must have shape (B, Ns). """ output = tgt position_bias = None @@ -825,12 +805,11 @@ def forward( # NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L835 class T5Decoder(nn.Module): - r"""T5 is a stack of N encoder/decoder layers + r"""T5Decoder is a stack of N decoder layers Args: - is_decoder: Whether or not the layer belongs to the decoder. (required) d_model: Number of expected features in the input (required). nhead: Number of heads in the multihead attention models (required). - num_layers: Number of encoder/decoder layers in the stack (required) + num_layers: Number of decoder layers in the stack (required) dim_feedforward: Dimension of the feedforward network model (default=3072). dropout: Dropout value (default=0.1). activation: Activation function of the intermediate layer, can be a string @@ -840,7 +819,7 @@ class T5Decoder(nn.Module): relative_attention_max_distance: Maximum threshold on the relative distance used to allocate buckets. 
Anything larger gets placed in the same bucket (defulat: 128) Examples:: - >>> decoder = nn.T5Stack(is_decoder=True, d_model=768, nhead=12, num_layers=12) + >>> decoder = T5Decoder(d_model=768, nhead=12, num_layers=12) >>> memory = torch.rand(32, 10, 512) >>> tgt = torch.rand(32, 10, 512) >>> out = decoder(tgt, memory) @@ -893,19 +872,19 @@ def forward( ) -> Tuple[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]], List[Optional[Tensor]]]: r"""Pass the inputs (and mask) through the stack of encoder/decoder layers. Args: - tgt: Input sequence to the encoder/decoder layer. (required). + tgt: Input sequence to the decoder layer. (required). Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence length, and E is the model dimension. - memory: Sequence from the last layer of the encoder (used for decoder only). (required). + memory: Sequence from the last layer of the encoder. (required). Must have shape (B, Nts, E) where B is the batch size, Ns is the source sequence length, and E is the model dimension. tgt_mask: Attention mask for self-attention. (optional). Must have shape (Nt, Nt). - memory_mask: Attention mask for cross-attention (decoder-only) (optional). + memory_mask: Attention mask for cross-attention (optional). Must have shape (Nt, Ns). tgt_key_padding_mask: Mask for the tgt keys per batch (optional). Must have shape (B, Nt). - memory_key_padding_mask: Mask for the memory keys per batch (decoder-only) (optional). + memory_key_padding_mask: Mask for the memory keys per batch (optional). Must have shape (B, Ns). 
""" output = tgt From d595a7fb153a1fcd5bd9ce2b1816d7492d740aa3 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 9 Aug 2022 11:11:54 -0400 Subject: [PATCH 5/6] correcting type annotations --- torchtext/prototype/models/t5/model.py | 7 ++++++- torchtext/prototype/models/t5/modules.py | 18 +++++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py index 63bfcc8109..7113dfd9d1 100644 --- a/torchtext/prototype/models/t5/model.py +++ b/torchtext/prototype/models/t5/model.py @@ -136,7 +136,7 @@ def __init__( def forward( self, encoder_tokens: Tensor, - decoder_tokens: Tensor = None, + decoder_tokens: Optional[Tensor] = None, encoder_mask: Optional[Tensor] = None, decoder_mask: Optional[Tensor] = None, ) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]: @@ -183,6 +183,7 @@ def forward( decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx if decoder_mask is None: + assert decoder_tokens is not None and decoder_tokens.dim() == 2 tgt_len = decoder_tokens.shape[1] decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1) decoder_mask = decoder_mask.to(torch.bool) @@ -232,4 +233,8 @@ def forward( "encoder_sa_scores": encoder_sa, } + assert torch.jit.isinstance( + t5_output, Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]] + ) + return t5_output diff --git a/torchtext/prototype/models/t5/modules.py b/torchtext/prototype/models/t5/modules.py index e9b749582b..806099f4ea 100644 --- a/torchtext/prototype/models/t5/modules.py +++ b/torchtext/prototype/models/t5/modules.py @@ -31,11 +31,11 @@ def __init__( is_decoder: bool = False, dropout: float = 0.0, bias: bool = False, - kdim: int = None, - vdim: int = None, - compute_relative_attention_bias=False, - relative_attention_num_buckets=32, - relative_attention_max_distance=128, + kdim: 
Optional[int] = None, + vdim: Optional[int] = None, + compute_relative_attention_bias: bool = False, + relative_attention_num_buckets: int = 32, + relative_attention_max_distance: int = 128, device: Optional[torch.device] = None, dtype=None, ) -> None: @@ -416,7 +416,7 @@ def _relative_position_bucket( # NOTE: Taken from https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L239 class T5LayerNorm(nn.Module): - def __init__(self, d_model, eps=1e-6) -> None: + def __init__(self, d_model: int, eps: float = 1e-6) -> None: """ Construct a layernorm module in the T5 style. No bias and no subtraction of mean. """ @@ -776,7 +776,7 @@ def forward( tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]: - r"""Pass the inputs (and mask) through the stack of encoder layers. + r"""Pass the input (and masks) through the stack of encoder layers. Args: tgt: Input sequence to the encoder layer. (required). Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence @@ -864,13 +864,13 @@ def __init__( def forward( self, tgt: Tensor, - memory: Tensor = None, + memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]], List[Optional[Tensor]]]: - r"""Pass the inputs (and mask) through the stack of encoder/decoder layers. + r"""Pass the inputs (and masks) through the stack of decoder layers. Args: tgt: Input sequence to the decoder layer. (required). 
Must have shape (B, Nt, E) where B is the batch size, Nt is the target sequence From 5441e9a7df14d0e11d115b22e7679b418cbebf92 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 9 Aug 2022 11:13:07 -0400 Subject: [PATCH 6/6] update integration tests to test scripted version of models --- .../integration_tests/test_models.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/test/prototype/integration_tests/test_models.py b/test/prototype/integration_tests/test_models.py index ed0f5b29e0..99da027e5e 100644 --- a/test/prototype/integration_tests/test_models.py +++ b/test/prototype/integration_tests/test_models.py @@ -1,4 +1,5 @@ import torch +from parameterized import parameterized from test.common.assets import get_asset_path from test.common.torchtext_test_case import TorchtextTestCase from torchtext.prototype.models import ( @@ -9,7 +10,7 @@ class TestT5(TorchtextTestCase): - def _t5_model(self, t5_model, expected_asset_name, test_text): + def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text): """Verify that pre-trained T5 models in torchtext produce the same output as the HuggingFace reference implementation. 
""" @@ -18,6 +19,10 @@ def _t5_model(self, t5_model, expected_asset_name, test_text): model = t5_model.get_model() model = model.eval() + if is_jit: + transform = torch.jit.script(transform) + model = torch.jit.script(model) + model_input = transform(test_text) if model.encoder_only: actual = model(model_input)["encoder_output"] @@ -27,17 +32,24 @@ def _t5_model(self, t5_model, expected_asset_name, test_text): expected = torch.load(expected_asset_path) torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06) - def test_t5_base_encoder_model(self) -> None: + @parameterized.expand([("jit", True), ("not_jit", False)]) + def test_t5_base_encoder_model(self, name, is_jit) -> None: expected_asset_name = "t5.base.encoder.output.pt" test_text = ["Hello world", "Attention rocks!"] - self._t5_model(t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, test_text=test_text) + self._t5_model( + is_jit=is_jit, t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, test_text=test_text + ) - def test_t5_base_model(self) -> None: + @parameterized.expand([("jit", True), ("not_jit", False)]) + def test_t5_base_model(self, name, is_jit) -> None: expected_asset_name = "t5.base.output.pt" test_text = ["Hello world", "Attention rocks!"] - self._t5_model(t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text) + self._t5_model(is_jit=is_jit, t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text) - def test_t5_base_generation_model(self) -> None: + @parameterized.expand([("jit", True), ("not_jit", False)]) + def test_t5_base_generation_model(self, name, is_jit) -> None: expected_asset_name = "t5.base.generation.output.pt" test_text = ["Hello world", "Attention rocks!"] - self._t5_model(t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text) + self._t5_model( + is_jit=is_jit, t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text + )