Merged
26 changes: 19 additions & 7 deletions test/prototype/integration_tests/test_models.py
@@ -1,4 +1,5 @@
 import torch
+from parameterized import parameterized
 from test.common.assets import get_asset_path
 from test.common.torchtext_test_case import TorchtextTestCase
 from torchtext.prototype.models import (
@@ -9,7 +10,7 @@
 
 
 class TestT5(TorchtextTestCase):
-    def _t5_model(self, t5_model, expected_asset_name, test_text):
+    def _t5_model(self, is_jit, t5_model, expected_asset_name, test_text):
         """Verify that pre-trained T5 models in torchtext produce
         the same output as the HuggingFace reference implementation.
         """
@@ -18,6 +19,10 @@ def _t5_model(self, t5_model, expected_asset_name, test_text):
         model = t5_model.get_model()
         model = model.eval()
 
+        if is_jit:
+            transform = torch.jit.script(transform)
+            model = torch.jit.script(model)
+
         model_input = transform(test_text)
         if model.encoder_only:
             actual = model(model_input)["encoder_output"]
@@ -27,17 +32,24 @@ def _t5_model(self, t5_model, expected_asset_name, test_text):
         expected = torch.load(expected_asset_path)
         torch.testing.assert_close(actual, expected, atol=1e-04, rtol=2.5e-06)
 
-    def test_t5_base_encoder_model(self) -> None:
+    @parameterized.expand([("jit", True), ("not_jit", False)])
+    def test_t5_base_encoder_model(self, name, is_jit) -> None:
         expected_asset_name = "t5.base.encoder.output.pt"
         test_text = ["Hello world", "Attention rocks!"]
-        self._t5_model(t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, test_text=test_text)
+        self._t5_model(
+            is_jit=is_jit, t5_model=T5_BASE_ENCODER, expected_asset_name=expected_asset_name, test_text=test_text
+        )
 
-    def test_t5_base_model(self) -> None:
+    @parameterized.expand([("jit", True), ("not_jit", False)])
+    def test_t5_base_model(self, name, is_jit) -> None:
         expected_asset_name = "t5.base.output.pt"
         test_text = ["Hello world", "Attention rocks!"]
-        self._t5_model(t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text)
+        self._t5_model(is_jit=is_jit, t5_model=T5_BASE, expected_asset_name=expected_asset_name, test_text=test_text)
 
-    def test_t5_base_generation_model(self) -> None:
+    @parameterized.expand([("jit", True), ("not_jit", False)])
+    def test_t5_base_generation_model(self, name, is_jit) -> None:
         expected_asset_name = "t5.base.generation.output.pt"
         test_text = ["Hello world", "Attention rocks!"]
-        self._t5_model(t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text)
+        self._t5_model(
+            is_jit=is_jit, t5_model=T5_BASE_GENERATION, expected_asset_name=expected_asset_name, test_text=test_text
+        )
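
Note: the parameterization above runs each integration test twice, once eagerly and once through TorchScript, asserting the same reference outputs in both modes. A minimal sketch of the same pattern outside torchtext, using a stand-in nn.Linear rather than the real T5 bundles:

import torch
from parameterized import parameterized
from unittest import TestCase


class TestJitParity(TestCase):
    @parameterized.expand([("jit", True), ("not_jit", False)])
    def test_forward_parity(self, name, is_jit):
        torch.manual_seed(0)
        eager = torch.nn.Linear(4, 2).eval()
        # Scripting must not change numerics: both paths share the same weights.
        model = torch.jit.script(eager) if is_jit else eager
        x = torch.randn(3, 4)
        torch.testing.assert_close(model(x), eager(x))

The `name` argument only suffixes the generated test names (e.g. `test_forward_parity_0_jit`, `test_forward_parity_1_not_jit`); `is_jit` drives the behavior.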
40 changes: 26 additions & 14 deletions torchtext/prototype/models/t5/model.py
@@ -1,11 +1,11 @@
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union, Callable
+from typing import Dict, List, Optional, Union, Callable
 
 import torch
 import torch.nn as nn
 from torch import Tensor
 
-from .modules import T5Stack, T5LayerNorm
+from .modules import T5Encoder, T5Decoder, T5LayerNorm
 
 
 @dataclass
@@ -69,14 +69,15 @@ def __init__(
         self,
         config: T5Conf,
         freeze: bool = False,
-        device=None,
+        device: Optional[torch.device] = None,
         dtype=None,
     ) -> None:
         super().__init__()
 
         assert isinstance(config, T5Conf)
 
         self.config = config
+        self.embedding_dim = config.embedding_dim
         self.encoder_only = config.encoder_only
         self.linear_head = config.linear_head
         self.padding_idx = config.padding_idx
@@ -86,8 +87,7 @@ def __init__(
         self.dtype = dtype
 
         self.token_embeddings = nn.Embedding(config.vocab_size, config.embedding_dim, config.padding_idx)
-        self.encoder = T5Stack(
-            is_decoder=False,
+        self.encoder = T5Encoder(
             d_model=config.embedding_dim,
             nhead=config.num_attention_heads,
             num_layers=config.num_encoder_layers,
@@ -105,8 +105,7 @@ def __init__(
         self.dropout2 = nn.Dropout(self.dropout)
 
         if not config.encoder_only:
-            self.decoder = T5Stack(
-                is_decoder=True,
+            self.decoder = T5Decoder(
                 d_model=config.embedding_dim,
                 nhead=config.num_attention_heads,
                 num_layers=config.num_decoder_layers,
@@ -122,9 +121,13 @@ def __init__(
             self.norm2 = T5LayerNorm(config.embedding_dim)
             self.dropout3 = nn.Dropout(self.dropout)
             self.dropout4 = nn.Dropout(self.dropout)
+        else:
+            self.decoder = None
 
         if config.linear_head:
             self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
+        else:
+            self.lm_head = None
 
         if freeze:
             for p in self.parameters():
@@ -133,10 +136,10 @@ def __init__(
     def forward(
         self,
         encoder_tokens: Tensor,
-        decoder_tokens: Tensor = None,
+        decoder_tokens: Optional[Tensor] = None,
         encoder_mask: Optional[Tensor] = None,
         decoder_mask: Optional[Tensor] = None,
-    ) -> Dict[str, Union[Tensor, Tuple[Tensor]]]:
+    ) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]:
         r"""Pass the inputs (and mask) through the decoder layer in turn.
         Args:
             encoder_tokens: Tokenized input sequence to the encoder.
@@ -163,23 +166,27 @@ def forward(
         """
         encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
         encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens))
-        encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa, _ = self.encoder(
+        encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa = self.encoder(
             encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask
         )
 
         encoder_output = self.norm1(encoder_output)
         encoder_output = self.dropout2(encoder_output)
-        encoder_hidden_states = encoder_hidden_states + (encoder_output,)
+        encoder_hidden_states.append(encoder_output)
 
         if not self.encoder_only:
+
+            assert self.decoder is not None
+
            # decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
            if decoder_tokens is None:
                decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx
 
            if decoder_mask is None:
                assert decoder_tokens is not None and decoder_tokens.dim() == 2
                tgt_len = decoder_tokens.shape[1]
-                decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
+                decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
+                decoder_mask = decoder_mask.to(torch.bool)
 
            decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
            # T5 implementation uses padding idx to start sequence. Want to ignore this when masking
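
Note on the mask change above: the causal mask is still a standard upper-triangular matrix from torch.triu; splitting the old `.bool()` call into a separate `.to(torch.bool)` statement only makes the dtype conversion explicit, presumably to keep the scripted graph's types easy to follow. A standalone sketch of the same mask:

import torch

tgt_len = 4
# True above the main diagonal marks positions a step may NOT attend to.
decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
decoder_mask = decoder_mask.to(torch.bool)
print(decoder_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])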
@@ -197,13 +204,14 @@ def forward(
 
             decoder_output = self.norm2(decoder_output)
             decoder_output = self.dropout4(decoder_output)
-            decoder_hidden_states = decoder_hidden_states + (decoder_output,)
+            decoder_hidden_states.append(decoder_output)
 
             if self.linear_head:
+                assert self.lm_head is not None
                 # Rescale output before projecting on vocab. This happens when the encoder and decoder share the
                 # same word embeddings, which is always the case in our t5 implementation.
                 # See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
-                decoder_output = decoder_output * (self.config.embedding_dim ** -0.5)
+                decoder_output = decoder_output * (self.embedding_dim ** -0.5)
                 decoder_output = self.lm_head(decoder_output)
 
             t5_output = {
@@ -225,4 +233,8 @@ def forward(
                 "encoder_sa_scores": encoder_sa,
             }
 
+        assert torch.jit.isinstance(
+            t5_output, Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]
+        )
+
         return t5_output
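
Most of the model changes exist to satisfy the TorchScript compiler: every attribute must be assigned on all code paths (hence `self.decoder = None` and `self.lm_head = None`), `Optional` values must be narrowed with an `is not None` check or assert before use, and `torch.jit.isinstance` checks a container against a richer type than plain `isinstance` can express. A minimal, self-contained sketch of those patterns, using an illustrative module that is not part of torchtext:

from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor


class MaybeDecoder(nn.Module):
    def __init__(self, use_decoder: bool) -> None:
        super().__init__()
        # TorchScript requires the attribute to exist on every code path,
        # so the disabled branch assigns None rather than omitting it.
        self.decoder: Optional[nn.Linear] = nn.Linear(4, 4) if use_decoder else None

    def forward(self, x: Tensor, bias: Optional[Tensor] = None) -> Dict[str, Union[Tensor, List[Tensor]]]:
        if bias is not None:
            # Optional arguments must be narrowed before use under TorchScript.
            x = x + bias
        hidden: List[Tensor] = [x]
        if self.decoder is not None:
            # The `is not None` check refines Optional[nn.Linear] for the compiler.
            x = self.decoder(x)
            hidden.append(x)
        out: Dict[str, Union[Tensor, List[Tensor]]] = {"output": x, "hidden": hidden}
        # Runtime type check that also works inside scripted code.
        assert torch.jit.isinstance(out, Dict[str, Union[Tensor, List[Tensor]]])
        return out


scripted = torch.jit.script(MaybeDecoder(use_decoder=True))
print(scripted(torch.randn(2, 4))["output"].shape)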