From 16349d0306246f2e29240df931daba060cbd82f3 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 18 Jul 2022 18:50:33 -0400 Subject: [PATCH 01/10] create T5 config class --- torchtext/prototype/t5/model.py | 142 +++++++++++++++++--------------- 1 file changed, 74 insertions(+), 68 deletions(-) diff --git a/torchtext/prototype/t5/model.py b/torchtext/prototype/t5/model.py index cd9ebf2367..d367ed4805 100644 --- a/torchtext/prototype/t5/model.py +++ b/torchtext/prototype/t5/model.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Dict, Optional, Tuple, Union, Callable import torch @@ -7,6 +8,25 @@ from .modules import T5Stack, T5LayerNorm +@dataclass +class T5Conf: + encoder_only: bool = False + embedding_dim: int = 768 + num_attention_heads: int = 12 + num_encoder_layers: int = 12 + num_decoder_layers: int = 12 + ffn_dimension: int = 3072 + dropout: float = 0.1 + activation: Union[str, Callable[[Tensor], Tensor]] = "relu" + layer_norm_eps: float = 1e-6 + relative_attention_num_buckets: int = 32 + relative_attention_max_distance: int = 128 + padding_idx: int = 0 + max_seq_len: int = 512 + vocab_size: int = 32128 + training: bool = False + + # NOTE: Comparable HuggingFace implentation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269 class T5Model(nn.Module): r"""A T5 model. User is able to modify the attributes as needed. The architecture @@ -15,22 +35,24 @@ class T5Model(nn.Module): Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research. Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html Args: - encoder_only: Whether or not model should consist of only the encoder as opposed to encoder-decoder (required) - d_model: Number of expected features in the encoder/decoder inputs (default=768). - nhead: Number of heads in the multiheadattention models (default=12). - num_encoder_layers: Number of encoder layers in the encoder (default=12). - num_decoder_layers: Number of decoder layers in the decoder (default=12). - dim_feedforward: Dimension of the feedforward network model (default=3072). - dropout: Dropout value (default=0.1). - activation: Activation function of encoder/decoder intermediate layer, can be a string + config.encoder_only: Whether or not model should consist of only the encoder as opposed to encoder-decoder (required) + config.embedding_dim: Number of expected features in the encoder/decoder inputs (default=768). + config.num_attention_heads: Number of heads in the multiheadattention models (default=12). + config.num_encoder_layers: Number of encoder layers in the encoder (default=12). + config.num_decoder_layers: Number of decoder layers in the decoder (default=12). + config.ffn_dimension: Dimension of the feedforward network model (default=3072). + config.dropout: Dropout value (default=0.1). + config.activation: Activation function of encoder/decoder intermediate layer, can be a string ("relu" or "gelu") or a unary callable. Default: relu - layer_norm_eps: The eps value in layer normalization components (default=1e-6). - relative_attention_num_buckets: Number of relative position buckets (default: 32) - relative_attention_max_distance: Maximum threshold on the relative distance used to + config.layer_norm_eps: The eps value in layer normalization components (default=1e-6). + config.relative_attention_num_buckets: Number of relative position buckets (default: 32) + config.relative_attention_max_distance: Maximum threshold on the relative distance used to allocate buckets. Anything larger gets placed in the same bucket (default: 128) - padding_idx: Index assigned to padding token in vocabulary (default: 0) - max_seq_len: Maximum sequence length (default: 512) - vocab_size: Size of vocabulary (default: 32128) + config.padding_idx: Index assigned to padding token in vocabulary (default: 0) + config.max_seq_len: Maximum sequence length (default: 512) + config.vocab_size: Size of vocabulary (default: 32128) + config.training: Whether or not to apply dropout (default: False) + freeze: Examples:: >>> t5_model = T5Model(encoder_only=False) >>> src = torch.rand((32, 10, 512)) @@ -40,79 +62,63 @@ class T5Model(nn.Module): def __init__( self, - encoder_only: bool, - d_model: int = 768, - nhead: int = 12, - num_encoder_layers: int = 12, - num_decoder_layers: int = 12, - dim_feedforward: int = 3072, - dropout: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = "relu", - layer_norm_eps: float = 1e-6, - relative_attention_num_buckets: int = 32, - relative_attention_max_distance: int = 128, - padding_idx: int = 0, - max_seq_len: int = 512, - vocab_size: int = 32128, + config: T5Conf, + freeze: bool = False, device=None, dtype=None, ) -> None: super().__init__() - self.encoder_only = encoder_only - self.d_model = d_model - self.dim_feedforward = dim_feedforward - self.dropout = dropout - self.activation = activation - self.layer_norm_eps = layer_norm_eps - self.nhead = nhead - self.num_encoder_layers = num_encoder_layers - self.num_decoder_layers = num_decoder_layers - self.relative_attention_num_buckets = relative_attention_num_buckets - self.realtive_attention_max_distance = relative_attention_max_distance - self.padding_idx = padding_idx - self.max_seq_len = max_seq_len - self.vocab_size = vocab_size + assert isinstance(config, T5Conf) + + self.encoder_only = config.encoder_only + self.padding_idx = config.padding_idx + self.training = config.training + self.dropout = config.dropout if config.training else 0.0 self.device = device self.dtype = dtype - self.token_embeddings = nn.Embedding(vocab_size, d_model, padding_idx) + self.token_embeddings = nn.Embedding(config.vocab_size, config.embedding_dim, config.padding_idx) self.encoder = T5Stack( is_decoder=False, - d_model=d_model, - nhead=nhead, - num_layers=num_encoder_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - relative_attention_num_buckets=relative_attention_num_buckets, - relative_attention_max_distance=relative_attention_max_distance, + d_model=config.embedding_dim, + nhead=config.num_attention_heads, + num_layers=config.num_encoder_layers, + dim_feedforward=config.ffn_dimension, + dropout=self.dropout, + activation=config.activation, + layer_norm_eps=config.layer_norm_eps, + relative_attention_num_buckets=config.relative_attention_num_buckets, + relative_attention_max_distance=config.relative_attention_max_distance, device=device, dtype=dtype, ) - self.norm1 = T5LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) + self.norm1 = T5LayerNorm(config.embedding_dim) + self.dropout1 = nn.Dropout(self.dropout) + self.dropout2 = nn.Dropout(self.dropout) - if not encoder_only: + if not config.encoder_only: self.decoder = T5Stack( is_decoder=True, - d_model=d_model, - nhead=nhead, - num_layers=num_decoder_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - relative_attention_num_buckets=relative_attention_num_buckets, - relative_attention_max_distance=relative_attention_max_distance, + d_model=config.embedding_dim, + nhead=config.num_attention_heads, + num_layers=config.num_decoder_layers, + dim_feedforward=config.ffn_dimension, + dropout=self.dropout, + activation=config.activation, + layer_norm_eps=config.layer_norm_eps, + relative_attention_num_buckets=config.relative_attention_num_buckets, + relative_attention_max_distance=config.relative_attention_max_distance, device=device, dtype=dtype, ) - self.norm2 = T5LayerNorm(d_model) - self.dropout3 = nn.Dropout(dropout) - self.dropout4 = nn.Dropout(dropout) + self.norm2 = T5LayerNorm(config.embedding_dim) + self.dropout3 = nn.Dropout(self.dropout) + self.dropout4 = nn.Dropout(self.dropout) + + if freeze: + for p in self.parameters(): + p.requires_grad = False def forward( self, From 04dba1cd83578eaaada41af4540725d2f25c5ad9 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 18 Jul 2022 18:57:06 -0400 Subject: [PATCH 02/10] bundler API used to load pre-trained weight for t5 base model --- torchtext/prototype/t5/bundler.py | 124 ++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 torchtext/prototype/t5/bundler.py diff --git a/torchtext/prototype/t5/bundler.py b/torchtext/prototype/t5/bundler.py new file mode 100644 index 0000000000..b3285164fe --- /dev/null +++ b/torchtext/prototype/t5/bundler.py @@ -0,0 +1,124 @@ +import logging +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union +from urllib.parse import urljoin + +import torch +from torchtext import _TEXT_BUCKET +from torchtext._download_hooks import load_state_dict_from_url + +from .model import T5Conf, T5Model + +logger = logging.getLogger(__name__) + + +@dataclass +class T5Bundle: + _config: T5Conf + _path: Optional[str] = None + + def get_model( + self, + *, + load_weights: bool = True, + freeze_model: bool = False, + dl_kwargs: Dict[str, Any] = None, + ) -> T5Model: + + if load_weights: + assert ( + self._path is not None + ), "load_weights cannot be True. The pre-trained model weights are not available for the current object" + + if freeze_model: + if not load_weights or not self._path: + logger.warning( + "The model is not loaded with pre-trained weights. Setting freeze_model to True will hinder model from learning appropriate weights." + ) + + return T5Bundle.build_model( + config=self._config, + freeze_model=freeze_model, + checkpoint=self._path if load_weights else None, + strict=True, + dl_kwargs=dl_kwargs, + ) + + @classmethod + def build_model( + cls, + config: T5Conf, + *, + freeze_model: bool = False, + checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None, + strict=False, + dl_kwargs: Dict[str, Any] = None, + ) -> T5Model: + + model = T5Model(config, freeze_model) + if checkpoint is not None: + if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]): + state_dict = checkpoint + elif isinstance(checkpoint, str): + dl_kwargs = {} if dl_kwargs is None else dl_kwargs + state_dict = load_state_dict_from_url(checkpoint, **dl_kwargs) + else: + raise TypeError( + "checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}".format(type(checkpoint)) + ) + + model.load_state_dict(state_dict, strict=strict) + + return model + + @property + def config(self) -> T5Conf: + return self._config + + +T5_BASE_ENCODER = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.base.encoder.pt"), + _config=T5Conf(encoder_only=True), +) + +T5_BASE_ENCODER.__doc__ = """ + T5 Encoder with Base configuration + + The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + `. It introduces a unified framework that converts text-based + language problems, such as translation, question-answering, and summarization, into a text-to-text format. The + Colossal Clean Crawled Corpus (C4) dataset is used to pre-train the model on a masked language modeling task, + and various datasets are used to fine-tune the model on each downstream task. The model's architecture is a modified version + of the canonical Transformer architecture. + + Originally published by the authors of T5 under Apache License, Version 2.0 + and redistributed with the same license. + [`License `__, + `Source `__] + + Please refer to :func:`torchtext.prototype.models.t5.T5Bundle` for the usage. + """ + + +T5_BASE = T5Bundle( + _path=urljoin(_TEXT_BUCKET, "t5.base.pt"), + _config=T5Conf(encoder_only=False), +) + +T5_BASE.__doc__ = """ + T5 Encoder-Decoder with Base configuration + + The T5 model was proposed in `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer + `. It introduces a unified framework that converts text-based + language problems, such as translation, question-answering, and summarization, into a text-to-text format. The + Colossal Clean Crawled Corpus (C4) dataset is used to pre-train the model on a masked language modeling task, + and various datasets are used to fine-tune the model on each downstream task. The model's architecture is a modified version + of the canonical Transformer architecture. + + Originally published by the authors of T5 under Apache License, Version 2.0 + and redistributed with the same license. + [`License `__, + `Source `__] + + Please refer to :func:`torchtext.prototype.models.t5.T5Bundle` for the usage. + """ From 0548feb43a9beefe87c7b5fbf9cf59fe93eaf48c Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 18 Jul 2022 19:15:50 -0400 Subject: [PATCH 03/10] adding docstring to bundler --- torchtext/prototype/t5/bundler.py | 32 +++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/torchtext/prototype/t5/bundler.py b/torchtext/prototype/t5/bundler.py index b3285164fe..0df7f5b2a9 100644 --- a/torchtext/prototype/t5/bundler.py +++ b/torchtext/prototype/t5/bundler.py @@ -14,6 +14,34 @@ @dataclass class T5Bundle: + """T5Bundle(_config: torchtext.prototype.models.T5Conf, _path: Optional[str] = None) + + Example - Pretrained base t5 encoder + >>> import torch, torchtext + >>> t5_encoder_base = torchtext.prototype.models.T5_BASE_ENCODER + >>> model = t5_encoder_base.get_model() + >>> model_input = torch.tensor([[1,2,3,4,5,6],[7,8,9,0,0,0]]) + >>> output = model(model_input)['encoder_output'] + >>> output.shape + torch.Size([2, 6, 768]) + + Example - Pretrained base t5 model + >>> import torch, torchtext + >>> t5_base = torchtext.prototype.models.T5_BASE + >>> model = t5_base.get_model() + >>> encoder_input = torch.tensor([[1,2,3,4,5,6],[7,8,9,0,0,0]]) + >>> decoder_input = torch.tensor([[0],[0]]) + >>> output = model(encoder_input, decoder_input)['decoder_output'] + >>> output.shape + torch.Size([2, 1, 768]) + + Example - User-specified configuration and checkpoint + >>> from torchtext.prototype.models import T5Conf, T5Bundle + >>> model_weights_path = "https://download.pytorch.org/models/text/t5.base.encoder.pt" + >>> encoder_conf = T5Conf(encoder_only=True) + >>> model = T5Bundle.build_model(encoder_conf=encoder_conf, checkpoint=model_weights_path) + """ + _config: T5Conf _path: Optional[str] = None @@ -96,7 +124,7 @@ def config(self) -> T5Conf: [`License `__, `Source `__] - Please refer to :func:`torchtext.prototype.models.t5.T5Bundle` for the usage. + Please refer to :func:`torchtext.prototype.models.T5Bundle` for the usage. """ @@ -120,5 +148,5 @@ def config(self) -> T5Conf: [`License `__, `Source `__] - Please refer to :func:`torchtext.prototype.models.t5.T5Bundle` for the usage. + Please refer to :func:`torchtext.prototype.models.T5Bundle` for the usage. """ From 3603aa50dc71c1fb07da29cc6fff591e233e8dd8 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 18 Jul 2022 19:16:57 -0400 Subject: [PATCH 04/10] __init__ for t5 imports --- torchtext/prototype/t5/__init__.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 torchtext/prototype/t5/__init__.py diff --git a/torchtext/prototype/t5/__init__.py b/torchtext/prototype/t5/__init__.py new file mode 100644 index 0000000000..f69829494d --- /dev/null +++ b/torchtext/prototype/t5/__init__.py @@ -0,0 +1,14 @@ +from .bundler import ( + T5_BASE_ENCODER, + T5_BASE, + T5Bundle, +) +from .model import T5Conf, T5Model + +__all__ = [ + "T5Conf", + "T5Model", + "T5Bundle", + "T5_BASE_ENCODER", + "T5_BASE", +] From f5128a0dedcef471fb97f6c886561a39e56ac5ac Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 18 Jul 2022 19:19:47 -0400 Subject: [PATCH 05/10] moving t5 under prototype/models --- torchtext/prototype/{ => models}/t5/__init__.py | 0 torchtext/prototype/{ => models}/t5/bundler.py | 0 torchtext/prototype/{ => models}/t5/model.py | 0 torchtext/prototype/{ => models}/t5/modules.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename torchtext/prototype/{ => models}/t5/__init__.py (100%) rename torchtext/prototype/{ => models}/t5/bundler.py (100%) rename torchtext/prototype/{ => models}/t5/model.py (100%) rename torchtext/prototype/{ => models}/t5/modules.py (100%) diff --git a/torchtext/prototype/t5/__init__.py b/torchtext/prototype/models/t5/__init__.py similarity index 100% rename from torchtext/prototype/t5/__init__.py rename to torchtext/prototype/models/t5/__init__.py diff --git a/torchtext/prototype/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py similarity index 100% rename from torchtext/prototype/t5/bundler.py rename to torchtext/prototype/models/t5/bundler.py diff --git a/torchtext/prototype/t5/model.py b/torchtext/prototype/models/t5/model.py similarity index 100% rename from torchtext/prototype/t5/model.py rename to torchtext/prototype/models/t5/model.py diff --git a/torchtext/prototype/t5/modules.py b/torchtext/prototype/models/t5/modules.py similarity index 100% rename from torchtext/prototype/t5/modules.py rename to torchtext/prototype/models/t5/modules.py From 61ba40fe1285f0f2cfb62b2e74b7d96a6c8773d0 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 18 Jul 2022 19:22:50 -0400 Subject: [PATCH 06/10] __init__ under prototype/models to allow for imports --- torchtext/prototype/models/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 torchtext/prototype/models/__init__.py diff --git a/torchtext/prototype/models/__init__.py b/torchtext/prototype/models/__init__.py new file mode 100644 index 0000000000..ab659dda3d --- /dev/null +++ b/torchtext/prototype/models/__init__.py @@ -0,0 +1 @@ +from .t5 import * # noqa: F401, F403 From 4aa3477f8f25f3780eb2d820efb1c8953eb3d0ee Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 18 Jul 2022 19:27:34 -0400 Subject: [PATCH 07/10] correct typo in bundler docstring --- torchtext/prototype/models/t5/bundler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py index 0df7f5b2a9..8a0b8a192b 100644 --- a/torchtext/prototype/models/t5/bundler.py +++ b/torchtext/prototype/models/t5/bundler.py @@ -39,7 +39,7 @@ class T5Bundle: >>> from torchtext.prototype.models import T5Conf, T5Bundle >>> model_weights_path = "https://download.pytorch.org/models/text/t5.base.encoder.pt" >>> encoder_conf = T5Conf(encoder_only=True) - >>> model = T5Bundle.build_model(encoder_conf=encoder_conf, checkpoint=model_weights_path) + >>> model = T5Bundle.build_model(config=encoder_conf, checkpoint=model_weights_path) """ _config: T5Conf From 8909a7b30bee9de44886a5e3e15f3ab122d6ee7b Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Mon, 18 Jul 2022 19:46:52 -0400 Subject: [PATCH 08/10] updating T5Model docstring --- torchtext/prototype/models/t5/model.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py index d367ed4805..87512c7f70 100644 --- a/torchtext/prototype/models/t5/model.py +++ b/torchtext/prototype/models/t5/model.py @@ -52,12 +52,14 @@ class T5Model(nn.Module): config.max_seq_len: Maximum sequence length (default: 512) config.vocab_size: Size of vocabulary (default: 32128) config.training: Whether or not to apply dropout (default: False) - freeze: - Examples:: - >>> t5_model = T5Model(encoder_only=False) - >>> src = torch.rand((32, 10, 512)) - >>> tgt = torch.rand((32, 20, 512)) - >>> out = t5_model(src, tgt) + freeze: Indicates whether or not to freeze the model weights. (default: False) + Examples: + from torchtext.prototype.models import T5Conf, T5Model + >>> t5_config = T5Conf(encoder_only=False) + >>> t5_model = T5Model(t5_config) + >>> encoder_input = torch.rand((32, 10, 512)) + >>> decoder_input = torch.rand((32, 20, 512)) + >>> out = t5_model(encoder_input, decoder_input) """ def __init__( From e92b65fdf20f6dafd4e57223f1d251edb74cabe7 Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 19 Jul 2022 10:38:11 -0400 Subject: [PATCH 09/10] add docstrings for bundler methods --- torchtext/prototype/models/t5/bundler.py | 16 ++++++++++++++++ torchtext/prototype/models/t5/model.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py index 8a0b8a192b..4cf9362486 100644 --- a/torchtext/prototype/models/t5/bundler.py +++ b/torchtext/prototype/models/t5/bundler.py @@ -52,6 +52,13 @@ def get_model( freeze_model: bool = False, dl_kwargs: Dict[str, Any] = None, ) -> T5Model: + r"""get_model(load_weights: bool = True, freeze_model: bool = False, *, dl_kwargs=None) -> torctext.prototype.models.T5Model + + Args: + load_weights (bool): Indicates whether or not to load weights if available. (Default: `True`) + freeze_model (bool): Indicates whether or not to freeze the model weights. (Default: `False`) + dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: `None`) + """ if load_weights: assert ( @@ -82,6 +89,15 @@ def build_model( strict=False, dl_kwargs: Dict[str, Any] = None, ) -> T5Model: + """Class builder method + + Args: + config (T5Conf): An instance of classT5Conf that defined the model configuration + freeze_model (bool): Indicates whether to freeze the model weights. (Default: `False`) + checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. state_dict can have partial weights i.e only for encoder. (Default: ``None``) + strict (bool): Passed to :func: `torch.nn.Module.load_state_dict` method. (Default: `False`) + dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: `None`) + """ model = T5Model(config, freeze_model) if checkpoint is not None: diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py index 87512c7f70..341919393f 100644 --- a/torchtext/prototype/models/t5/model.py +++ b/torchtext/prototype/models/t5/model.py @@ -54,7 +54,7 @@ class T5Model(nn.Module): config.training: Whether or not to apply dropout (default: False) freeze: Indicates whether or not to freeze the model weights. (default: False) Examples: - from torchtext.prototype.models import T5Conf, T5Model + >>> from torchtext.prototype.models import T5Conf, T5Model >>> t5_config = T5Conf(encoder_only=False) >>> t5_model = T5Model(t5_config) >>> encoder_input = torch.rand((32, 10, 512)) From b16038346da4bda52975f8c0384aba0c2789240d Mon Sep 17 00:00:00 2001 From: pmabbo13 Date: Tue, 19 Jul 2022 12:01:49 -0400 Subject: [PATCH 10/10] initialize decoder input sequence to be padding index without requiring user input --- torchtext/prototype/models/t5/bundler.py | 5 ++--- torchtext/prototype/models/t5/model.py | 9 +++++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/torchtext/prototype/models/t5/bundler.py b/torchtext/prototype/models/t5/bundler.py index 4cf9362486..ed2262db82 100644 --- a/torchtext/prototype/models/t5/bundler.py +++ b/torchtext/prototype/models/t5/bundler.py @@ -29,9 +29,8 @@ class T5Bundle: >>> import torch, torchtext >>> t5_base = torchtext.prototype.models.T5_BASE >>> model = t5_base.get_model() - >>> encoder_input = torch.tensor([[1,2,3,4,5,6],[7,8,9,0,0,0]]) - >>> decoder_input = torch.tensor([[0],[0]]) - >>> output = model(encoder_input, decoder_input)['decoder_output'] + >>> model_input = torch.tensor([[1,2,3,4,5,6],[7,8,9,0,0,0]]) + >>> output = model(model_input)['decoder_output'] >>> output.shape torch.Size([2, 1, 768]) diff --git a/torchtext/prototype/models/t5/model.py b/torchtext/prototype/models/t5/model.py index 341919393f..41ffb12b4e 100644 --- a/torchtext/prototype/models/t5/model.py +++ b/torchtext/prototype/models/t5/model.py @@ -136,7 +136,8 @@ def forward( encoder input sequence length. (required). decoder_tokens: Tokenized input sequence to the decoder. Must be batch first with shape (B, Nd) where B is the batch size and Nd is the - decoder input sequence length. (required). + decoder input sequence length. If None and model is encoder-decoder, will initialize decoder + input sequence to begin with padding index. (optional). encoder_mask: Self-attention mask for the encoder input sequence. Must have shape (Ne, Ne) (optional). decoder_mask: Self-attention mask for the decoder input sequence. @@ -163,7 +164,11 @@ def forward( encoder_hidden_states = encoder_hidden_states + (encoder_output,) if not self.encoder_only: - assert decoder_tokens is not None + + # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx. + if decoder_tokens is None: + decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx + if decoder_mask is None: tgt_len = decoder_tokens.shape[1] decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()