|
| 1 | +from typing import Dict, Optional, Tuple, Union, Callable |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +from torch import Tensor |
| 6 | + |
| 7 | +from .modules import T5Stack, T5LayerNorm |
| 8 | + |
| 9 | + |
| 10 | +# NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269 |
| 11 | +class T5Model(nn.Module): |
| 12 | + r"""A T5 model. User is able to modify the attributes as needed. The architecture |
| 13 | + is based on the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer". |
| 14 | + Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, |
| 15 | + Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research. |
| 16 | + Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html |
| 17 | + Args: |
| 18 | + encoder_only: Whether or not model should consist of only the encoder as opposed to encoder-decoder (required) |
| 19 | + d_model: Number of expected features in the encoder/decoder inputs (default=768). |
| 20 | + nhead: Number of heads in the multiheadattention models (default=12). |
| 21 | + num_encoder_layers: Number of encoder layers in the encoder (default=12). |
| 22 | + num_decoder_layers: Number of decoder layers in the decoder (default=12). |
| 23 | + dim_feedforward: Dimension of the feedforward network model (default=3072). |
| 24 | + dropout: Dropout value (default=0.1). |
| 25 | + activation: Activation function of encoder/decoder intermediate layer, can be a string |
| 26 | + ("relu" or "gelu") or a unary callable. Default: relu |
| 27 | + layer_norm_eps: The eps value in layer normalization components (default=1e-6). |
| 28 | + relative_attention_num_buckets: Number of relative position buckets (default: 32) |
| 29 | + relative_attention_max_distance: Maximum threshold on the relative distance used to |
| 30 | + allocate buckets. Anything larger gets placed in the same bucket (default: 128) |
| 31 | + padding_idx: Index assigned to padding token in vocabulary (default: 0) |
| 32 | + max_seq_len: Maximum sequence length (default: 512) |
| 33 | + vocab_size: Size of vocabulary (default: 32128) |
| 34 | + Examples:: |
| 35 | + >>> t5_model = T5Model(encoder_only=False) |
| 36 | + >>> src = torch.rand((32, 10, 512)) |
| 37 | + >>> tgt = torch.rand((32, 20, 512)) |
| 38 | + >>> out = t5_model(src, tgt) |
| 39 | + """ |
| 40 | + |
| 41 | + def __init__( |
| 42 | + self, |
| 43 | + encoder_only: bool, |
| 44 | + d_model: int = 768, |
| 45 | + nhead: int = 12, |
| 46 | + num_encoder_layers: int = 12, |
| 47 | + num_decoder_layers: int = 12, |
| 48 | + dim_feedforward: int = 3072, |
| 49 | + dropout: float = 0.1, |
| 50 | + activation: Union[str, Callable[[Tensor], Tensor]] = "relu", |
| 51 | + layer_norm_eps: float = 1e-6, |
| 52 | + relative_attention_num_buckets: int = 32, |
| 53 | + relative_attention_max_distance: int = 128, |
| 54 | + padding_idx: int = 0, |
| 55 | + max_seq_len: int = 512, |
| 56 | + vocab_size: int = 32128, |
| 57 | + device=None, |
| 58 | + dtype=None, |
| 59 | + ) -> None: |
| 60 | + super().__init__() |
| 61 | + |
| 62 | + self.encoder_only = encoder_only |
| 63 | + self.d_model = d_model |
| 64 | + self.dim_feedforward = dim_feedforward |
| 65 | + self.dropout = dropout |
| 66 | + self.activation = activation |
| 67 | + self.layer_norm_eps = layer_norm_eps |
| 68 | + self.nhead = nhead |
| 69 | + self.num_encoder_layers = num_encoder_layers |
| 70 | + self.num_decoder_layers = num_decoder_layers |
| 71 | + self.relative_attention_num_buckets = relative_attention_num_buckets |
| 72 | + self.realtive_attention_max_distance = relative_attention_max_distance |
| 73 | + self.padding_idx = padding_idx |
| 74 | + self.max_seq_len = max_seq_len |
| 75 | + self.vocab_size = vocab_size |
| 76 | + self.device = device |
| 77 | + self.dtype = dtype |
| 78 | + |
| 79 | + self.token_embeddings = nn.Embedding(vocab_size, d_model, padding_idx) |
| 80 | + self.encoder = T5Stack( |
| 81 | + is_decoder=False, |
| 82 | + d_model=d_model, |
| 83 | + nhead=nhead, |
| 84 | + num_layers=num_encoder_layers, |
| 85 | + dim_feedforward=dim_feedforward, |
| 86 | + dropout=dropout, |
| 87 | + activation=activation, |
| 88 | + layer_norm_eps=layer_norm_eps, |
| 89 | + relative_attention_num_buckets=relative_attention_num_buckets, |
| 90 | + relative_attention_max_distance=relative_attention_max_distance, |
| 91 | + device=device, |
| 92 | + dtype=dtype, |
| 93 | + ) |
| 94 | + self.norm1 = T5LayerNorm(d_model) |
| 95 | + self.dropout1 = nn.Dropout(dropout) |
| 96 | + self.dropout2 = nn.Dropout(dropout) |
| 97 | + |
| 98 | + if not encoder_only: |
| 99 | + self.decoder = T5Stack( |
| 100 | + is_decoder=True, |
| 101 | + d_model=d_model, |
| 102 | + nhead=nhead, |
| 103 | + num_layers=num_decoder_layers, |
| 104 | + dim_feedforward=dim_feedforward, |
| 105 | + dropout=dropout, |
| 106 | + activation=activation, |
| 107 | + layer_norm_eps=layer_norm_eps, |
| 108 | + relative_attention_num_buckets=relative_attention_num_buckets, |
| 109 | + relative_attention_max_distance=relative_attention_max_distance, |
| 110 | + device=device, |
| 111 | + dtype=dtype, |
| 112 | + ) |
| 113 | + self.norm2 = T5LayerNorm(d_model) |
| 114 | + self.dropout3 = nn.Dropout(dropout) |
| 115 | + self.dropout4 = nn.Dropout(dropout) |
| 116 | + |
| 117 | + def forward( |
| 118 | + self, |
| 119 | + encoder_tokens: Tensor, |
| 120 | + decoder_tokens: Tensor = None, |
| 121 | + encoder_mask: Optional[Tensor] = None, |
| 122 | + decoder_mask: Optional[Tensor] = None, |
| 123 | + ) -> Dict[str, Union[Tensor, Tuple[Tensor]]]: |
| 124 | + r"""Pass the inputs (and mask) through the decoder layer in turn. |
| 125 | + Args: |
| 126 | + encoder_tokens: Tokenized input sequence to the encoder. |
| 127 | + Must be batch first with shape (B, Ne) where B is the batch size and Ne is the |
| 128 | + encoder input sequence length. (required). |
| 129 | + decoder_tokens: Tokenized input sequence to the decoder. |
| 130 | + Must be batch first with shape (B, Nd) where B is the batch size and Nd is the |
| 131 | + decoder input sequence length. (required). |
| 132 | + encoder_mask: Self-attention mask for the encoder input sequence. |
| 133 | + Must have shape (Ne, Ne) (optional). |
| 134 | + decoder_mask: Self-attention mask for the decoder input sequence. |
| 135 | + Must have shape (Nd, Nd) (optional). |
| 136 | + Returns: |
| 137 | + encoder_output: Output Tensor from the final layer of the encoder |
| 138 | + encoder_hidden_states: Tuple of output Tensors from each layer of the encoder |
| 139 | + encoder_position_bias: Tensor of relative attention bias computed for input sequence to encoder |
| 140 | + encoder_sa_scores: Tuple of self-attention scores computed at each layer of the encoder |
| 141 | + decoder_output: Output Tensor from the final layer of the decoder |
| 142 | + decoder_hidden_states: Tuple of output Tensors from each layer of the decoder |
| 143 | + decoder_position_bias: Tensor of relative attention bias computed for input sequence to decoder |
| 144 | + encoder_sa_scores: Tuple of self-attention scores computed at each layer of the decoder |
| 145 | + encoder_ca_scores: Tuple of cross-attention scores computed at each layer of the decoder |
| 146 | + """ |
| 147 | + encoder_padding_mask = encoder_tokens.eq(self.padding_idx) |
| 148 | + encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens)) |
| 149 | + encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa, _ = self.encoder( |
| 150 | + encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask |
| 151 | + ) |
| 152 | + |
| 153 | + encoder_output = self.norm1(encoder_output) |
| 154 | + encoder_output = self.dropout2(encoder_output) |
| 155 | + encoder_hidden_states = encoder_hidden_states + (encoder_output,) |
| 156 | + |
| 157 | + if not self.encoder_only: |
| 158 | + assert decoder_tokens is not None |
| 159 | + if decoder_mask is None: |
| 160 | + tgt_len = decoder_tokens.shape[1] |
| 161 | + decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool() |
| 162 | + |
| 163 | + decoder_padding_mask = decoder_tokens.eq(self.padding_idx) |
| 164 | + # T5 implemention uses padding idx to start sequence. Want to ignore this when masking |
| 165 | + decoder_padding_mask[:, 0] = False |
| 166 | + |
| 167 | + decoder_embeddings = self.dropout3(self.token_embeddings(decoder_tokens)) |
| 168 | + decoder_output, decoder_hidden_states, decoder_position_bias, decoder_sa, decoder_ca = self.decoder( |
| 169 | + decoder_embeddings, |
| 170 | + memory=encoder_output, |
| 171 | + tgt_mask=decoder_mask, |
| 172 | + memory_mask=encoder_mask, |
| 173 | + tgt_key_padding_mask=decoder_padding_mask, |
| 174 | + memory_key_padding_mask=encoder_padding_mask, |
| 175 | + ) |
| 176 | + |
| 177 | + decoder_output = self.norm2(decoder_output) |
| 178 | + decoder_output = self.dropout4(decoder_output) |
| 179 | + decoder_hidden_states = decoder_hidden_states + (decoder_output,) |
| 180 | + |
| 181 | + t5_output = { |
| 182 | + "encoder_output": encoder_output, |
| 183 | + "encoder_hidden_states": encoder_hidden_states, |
| 184 | + "encoder_position_bias": encoder_position_bias, |
| 185 | + "encoder_sa_scores": encoder_sa, |
| 186 | + "decoder_output": decoder_output, |
| 187 | + "decoder_hidden_states": decoder_hidden_states, |
| 188 | + "decoder_position_bias": decoder_position_bias, |
| 189 | + "decoder_sa_scores": decoder_sa, |
| 190 | + "decoder_ca_scores": decoder_ca, |
| 191 | + } |
| 192 | + else: |
| 193 | + t5_output = { |
| 194 | + "encoder_output": encoder_output, |
| 195 | + "encoder_hidden_states": encoder_hidden_states, |
| 196 | + "encoder_position_bias": encoder_position_bias, |
| 197 | + "encoder_sa_scores": encoder_sa, |
| 198 | + } |
| 199 | + |
| 200 | + return t5_output |
0 commit comments