diff --git a/torchtext/prototype/models/__init__.py b/torchtext/prototype/models/__init__.py new file mode 100644 index 0000000000..ab659dda3d --- /dev/null +++ b/torchtext/prototype/models/__init__.py @@ -0,0 +1 @@ +from .t5 import * # noqa: F401, F403 diff --git a/torchtext/prototype/models/t5/__init__.py b/torchtext/prototype/models/t5/__init__.py new file mode 100644 index 0000000000..f69829494d --- /dev/null +++ b/torchtext/prototype/models/t5/__init__.py @@ -0,0 +1,14 @@ +from .bundler import ( + T5_BASE_ENCODER, + T5_BASE, + T5Bundle, +) +from .model import T5Conf, T5Model + +__all__ = [ + "T5Conf", + "T5Model", + "T5Bundle", + "T5_BASE_ENCODER", + "T5_BASE", +] diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py new file mode 100644 index 0000000000..ed2262db82 --- /dev/null +++ b/torchtext/prototype/models/t5/bundler.py @@ -0,0 +1,167 @@ +import logging +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union +from urllib.parse import urljoin + +import torch +from torchtext import _TEXT_BUCKET +from torchtext._download_hooks import load_state_dict_from_url + +from .model import T5Conf, T5Model + +logger = logging.getLogger(__name__) + + +@dataclass +class T5Bundle: + """T5Bundle(_config: torchtext.prototype.models.T5Conf, _path: Optional[str] = None) + + Example - Pretrained base t5 encoder + >>> import torch, torchtext + >>> t5_encoder_base = torchtext.prototype.models.T5_BASE_ENCODER + >>> model = t5_encoder_base.get_model() + >>> model_input = torch.tensor([[1,2,3,4,5,6],[7,8,9,0,0,0]]) + >>> output = model(model_input)['encoder_output'] + >>> output.shape + torch.Size([2, 6, 768]) + + Example - Pretrained base t5 model + >>> import torch, torchtext + >>> t5_base = torchtext.prototype.models.T5_BASE + >>> model = t5_base.get_model() + >>> model_input = torch.tensor([[1,2,3,4,5,6],[7,8,9,0,0,0]]) + >>> output = model(model_input)['decoder_output'] + >>> output.shape + torch.Size([2, 1, 768]) + + Example - User-specified configuration and checkpoint + >>> from torchtext.prototype.models import T5Conf, T5Bundle + >>> model_weights_path = "https://download.pytorch.org/models/text/t5.base.encoder.pt" + >>> encoder_conf = T5Conf(encoder_only=True) + >>> model = T5Bundle.build_model(config=encoder_conf, checkpoint=model_weights_path) + """ + + _config: T5Conf + _path: Optional[str] = None + + def get_model( + self, + *, + load_weights: bool = True, + freeze_model: bool = False, + dl_kwargs: Dict[str, Any] = None, + ) -> T5Model: + r"""get_model(load_weights: bool = True, freeze_model: bool = False, *, dl_kwargs=None) -> torctext.prototype.models.T5Model + + Args: + load_weights (bool): Indicates whether or not to load weights if available. (Default: `True`) + freeze_model (bool): Indicates whether or not to freeze the model weights. (Default: `False`) + dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: `None`) + """ + + if load_weights: + assert ( + self._path is not None + ), "load_weights cannot be True. The pre-trained model weights are not available for the current object" + + if freeze_model: + if not load_weights or not self._path: + logger.warning( + "The model is not loaded with pre-trained weights. Setting freeze_model to True will hinder model from learning appropriate weights." + ) + + return T5Bundle.build_model( + config=self._config, + freeze_model=freeze_model, + checkpoint=self._path if load_weights else None, + strict=True, + dl_kwargs=dl_kwargs, + ) + + @classmethod + def build_model( + cls, + config: T5Conf, + *, + freeze_model: bool = False, + checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None, + strict=False, + dl_kwargs: Dict[str, Any] = None, + ) -> T5Model: + """Class builder method + + Args: + config (T5Conf): An instance of classT5Conf that defined the model configuration + freeze_model (bool): Indicates whether to freeze the model weights. (Default: `False`) + checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. state_dict can have partial weights i.e only for encoder. (Default: ``None``) + strict (bool): Passed to :func: `torch.nn.Module.load_state_dict` method. (Default: `False`) + dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: `None`) + """ + + model = T5Model(config, freeze_model) + if checkpoint is not None: + if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]): + state_dict = checkpoint + elif isinstance(checkpoint, str): + dl_kwargs = {} if dl_kwargs is None else dl_kwargs + state_dict = load_state_dict_from_url(checkpoint, **dl_kwargs) + else: + raise TypeError( + "checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}".format(type(checkpoint)) + ) + + model.load_state_dict(state_dict, strict=strict) + + return model + + @property + def config(self) -> T5Conf: + return self._config + + +T5_BASE_ENCODER = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.base.encoder.pt"), + _config=T5Conf(encoder_only=True), +) + +T5_BASE_ENCODER.__doc__ = """ + T5 Encoder with Base configuration + + The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + `. It introduces a unified framework that converts text-based + language problems, such as translation, question-answering, and summarization, into a text-to-text format. The + Colossal Clean Crawled Corpus (C4) dataset is used to pre-train the model on a masked language modeling task, + and various datasets are used to fine-tune the model on each downstream task. The model's architecture is a modified version + of the canonical Transformer architecture. + + Originally published by the authors of T5 under Apache License, Version 2.0 + and redistributed with the same license. + [`License `__, + `Source `__] + + Please refer to :func:`torchtext.prototype.models.T5Bundle` for the usage. + """ + + +T5_BASE = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.base.pt"), + _config=T5Conf(encoder_only=False), +) + +T5_BASE.__doc__ = """ + T5 Encoder-Decoder with Base configuration + + The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + `. It introduces a unified framework that converts text-based + language problems, such as translation, question-answering, and summarization, into a text-to-text format. The + Colossal Clean Crawled Corpus (C4) dataset is used to pre-train the model on a masked language modeling task, + and various datasets are used to fine-tune the model on each downstream task. The model's architecture is a modified version + of the canonical Transformer architecture. + + Originally published by the authors of T5 under Apache License, Version 2.0 + and redistributed with the same license. + [`License `__, + `Source `__] + + Please refer to :func:`torchtext.prototype.models.T5Bundle` for the usage. + """ diff --git a/torchtext/prototype/t5/model.py b/torchtext/prototype/models/t5/model.py similarity index 55% rename from torchtext/prototype/t5/model.py rename to torchtext/prototype/models/t5/model.py index cd9ebf2367..41ffb12b4e 100644 --- a/torchtext/prototype/t5/model.py +++ b/torchtext/prototype/models/t5/model.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union, Callable import torch @@ -7,6 +8,25 @@ from .modules import T5Stack, T5LayerNorm +@dataclass +class T5Conf: + encoder_only: bool = False + embedding_dim: int = 768 + num_attention_heads: int = 12 + num_encoder_layers: int = 12 + num_decoder_layers: int = 12 + ffn_dimension: int = 3072 + dropout: float = 0.1 + activation: Union[str, Callable[[Tensor], Tensor]] = "relu" + layer_norm_eps: float = 1e-6 + relative_attention_num_buckets: int = 32 + relative_attention_max_distance: int = 128 + padding_idx: int = 0 + max_seq_len: int = 512 + vocab_size: int = 32128 + training: bool = False + + # NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269 class T5Model(nn.Module): r"""A T5 model. User is able to modify the attributes as needed. The architecture @@ -15,104 +35,92 @@ class T5Model(nn.Module): Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research. Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html Args: - encoder_only: Whether or not model should consist of only the encoder as opposed to encoder-decoder (required) - d_model: Number of expected features in the encoder/decoder inputs (default=768). - nhead: Number of heads in the multiheadattention models (default=12). - num_encoder_layers: Number of encoder layers in the encoder (default=12). - num_decoder_layers: Number of decoder layers in the decoder (default=12). - dim_feedforward: Dimension of the feedforward network model (default=3072). - dropout: Dropout value (default=0.1). - activation: Activation function of encoder/decoder intermediate layer, can be a string + config.encoder_only: Whether or not model should consist of only the encoder as opposed to encoder-decoder (required) + config.embedding_dim: Number of expected features in the encoder/decoder inputs (default=768). + config.num_attention_heads: Number of heads in the multiheadattention models (default=12). + config.num_encoder_layers: Number of encoder layers in the encoder (default=12). + config.num_decoder_layers: Number of decoder layers in the decoder (default=12). + config.ffn_dimension: Dimension of the feedforward network model (default=3072). + config.dropout: Dropout value (default=0.1). + config.activation: Activation function of encoder/decoder intermediate layer, can be a string ("relu" or "gelu") or a unary callable. Default: relu - layer_norm_eps: The eps value in layer normalization components (default=1e-6). - relative_attention_num_buckets: Number of relative position buckets (default: 32) - relative_attention_max_distance: Maximum threshold on the relative distance used to + config.layer_norm_eps: The eps value in layer normalization components (default=1e-6). + config.relative_attention_num_buckets: Number of relative position buckets (default: 32) + config.relative_attention_max_distance: Maximum threshold on the relative distance used to allocate buckets. Anything larger gets placed in the same bucket (default: 128) - padding_idx: Index assigned to padding token in vocabulary (default: 0) - max_seq_len: Maximum sequence length (default: 512) - vocab_size: Size of vocabulary (default: 32128) - Examples:: - >>> t5_model = T5Model(encoder_only=False) - >>> src = torch.rand((32, 10, 512)) - >>> tgt = torch.rand((32, 20, 512)) - >>> out = t5_model(src, tgt) + config.padding_idx: Index assigned to padding token in vocabulary (default: 0) + config.max_seq_len: Maximum sequence length (default: 512) + config.vocab_size: Size of vocabulary (default: 32128) + config.training: Whether or not to apply dropout (default: False) + freeze: Indicates whether or not to freeze the model weights. (default: False) + Examples: + >>> from torchtext.prototype.models import T5Conf, T5Model + >>> t5_config = T5Conf(encoder_only=False) + >>> t5_model = T5Model(t5_config) + >>> encoder_input = torch.rand((32, 10, 512)) + >>> decoder_input = torch.rand((32, 20, 512)) + >>> out = t5_model(encoder_input, decoder_input) """ def __init__( self, - encoder_only: bool, - d_model: int = 768, - nhead: int = 12, - num_encoder_layers: int = 12, - num_decoder_layers: int = 12, - dim_feedforward: int = 3072, - dropout: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = "relu", - layer_norm_eps: float = 1e-6, - relative_attention_num_buckets: int = 32, - relative_attention_max_distance: int = 128, - padding_idx: int = 0, - max_seq_len: int = 512, - vocab_size: int = 32128, + config: T5Conf, + freeze: bool = False, device=None, dtype=None, ) -> None: super().__init__() - self.encoder_only = encoder_only - self.d_model = d_model - self.dim_feedforward = dim_feedforward - self.dropout = dropout - self.activation = activation - self.layer_norm_eps = layer_norm_eps - self.nhead = nhead - self.num_encoder_layers = num_encoder_layers - self.num_decoder_layers = num_decoder_layers - self.relative_attention_num_buckets = relative_attention_num_buckets - self.realtive_attention_max_distance = relative_attention_max_distance - self.padding_idx = padding_idx - self.max_seq_len = max_seq_len - self.vocab_size = vocab_size + assert isinstance(config, T5Conf) + + self.encoder_only = config.encoder_only + self.padding_idx = config.padding_idx + self.training = config.training + self.dropout = config.dropout if config.training else 0.0 self.device = device self.dtype = dtype - self.token_embeddings = nn.Embedding(vocab_size, d_model, padding_idx) + self.token_embeddings = nn.Embedding(config.vocab_size, config.embedding_dim, config.padding_idx) self.encoder = T5Stack( is_decoder=False, - d_model=d_model, - nhead=nhead, - num_layers=num_encoder_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - relative_attention_num_buckets=relative_attention_num_buckets, - relative_attention_max_distance=relative_attention_max_distance, + d_model=config.embedding_dim, + nhead=config.num_attention_heads, + num_layers=config.num_encoder_layers, + dim_feedforward=config.ffn_dimension, + dropout=self.dropout, + activation=config.activation, + layer_norm_eps=config.layer_norm_eps, + relative_attention_num_buckets=config.relative_attention_num_buckets, + relative_attention_max_distance=config.relative_attention_max_distance, device=device, dtype=dtype, ) - self.norm1 = T5LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) + self.norm1 = T5LayerNorm(config.embedding_dim) + self.dropout1 = nn.Dropout(self.dropout) + self.dropout2 = nn.Dropout(self.dropout) - if not encoder_only: + if not config.encoder_only: self.decoder = T5Stack( is_decoder=True, - d_model=d_model, - nhead=nhead, - num_layers=num_decoder_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - relative_attention_num_buckets=relative_attention_num_buckets, - relative_attention_max_distance=relative_attention_max_distance, + d_model=config.embedding_dim, + nhead=config.num_attention_heads, + num_layers=config.num_decoder_layers, + dim_feedforward=config.ffn_dimension, + dropout=self.dropout, + activation=config.activation, + layer_norm_eps=config.layer_norm_eps, + relative_attention_num_buckets=config.relative_attention_num_buckets, + relative_attention_max_distance=config.relative_attention_max_distance, device=device, dtype=dtype, ) - self.norm2 = T5LayerNorm(d_model) - self.dropout3 = nn.Dropout(dropout) - self.dropout4 = nn.Dropout(dropout) + self.norm2 = T5LayerNorm(config.embedding_dim) + self.dropout3 = nn.Dropout(self.dropout) + self.dropout4 = nn.Dropout(self.dropout) + + if freeze: + for p in self.parameters(): + p.requires_grad = False def forward( self, @@ -128,7 +136,8 @@ def forward( encoder input sequence length. (required). decoder_tokens: Tokenized input sequence to the decoder. Must be batch first with shape (B, Nd) where B is the batch size and Nd is the - decoder input sequence length. (required). + decoder input sequence length. If None and model is encoder-decoder, will initialize decoder + input sequence to begin with padding index. (optional). encoder_mask: Self-attention mask for the encoder input sequence. Must have shape (Ne, Ne) (optional). decoder_mask: Self-attention mask for the decoder input sequence. @@ -155,7 +164,11 @@ def forward( encoder_hidden_states = encoder_hidden_states + (encoder_output,) if not self.encoder_only: - assert decoder_tokens is not None + + # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx. + if decoder_tokens is None: + decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx + if decoder_mask is None: tgt_len = decoder_tokens.shape[1] decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool() diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/models/t5/modules.py similarity index 100% rename from torchtext/prototype/t5/modules.py rename to torchtext/prototype/models/t5/modules.py