14 changes: 5 additions & 9 deletions .circleci/config.yml

Some generated files are not rendered by default.

14 changes: 5 additions & 9 deletions .circleci/config.yml.in
@@ -23,12 +23,8 @@ commands:
      - run:
          name: adding UPLOAD_CHANNEL to BASH_ENV
          command: |
            our_upload_channel=nightly
            # On tags upload to test instead
            if [[ -n "${CIRCLE_TAG}" ]] || [[ ${CIRCLE_BRANCH} =~ release/* ]]; then
              our_upload_channel=test
            fi
            echo "export UPLOAD_CHANNEL=${our_upload_channel}" >> ${BASH_ENV}
            # hardcoded upload channel for release
            echo "export UPLOAD_CHANNEL=test" >> ${BASH_ENV}
  load_conda_channel_flags:
    description: "Determines whether we need extra conda channels"
    steps:
@@ -43,15 +39,15 @@ binary_common: &binary_common
    build_version:
      description: "version number of release binary; by default, build a nightly"
      type: string
      default: ""
      default: "0.15.0"
    pytorch_version:
      description: "PyTorch version to build against; by default, use a nightly"
      type: string
      default: ""
      default: "2.0.0"
    torchdata_version:
      description: "TorchData version to build against; by default, use a nightly"
      type: string
      default: ""
      default: "0.6.0"
    # Don't edit these
    python_version:
      description: "Python version to build against (e.g., 3.8)"
164 changes: 18 additions & 146 deletions examples/tutorials/t5_demo.py
@@ -3,6 +3,7 @@
==========================================================================

**Author**: `Pendo Abbo <[email protected]>`__
**Author**: `Joe Cummings <[email protected]>`__

"""

@@ -24,7 +25,6 @@
# Common imports
# --------------
import torch
import torch.nn.functional as F

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@@ -47,7 +47,7 @@
# the T5 model expects the input to be batched.
#

from torchtext.prototype.models import T5Transform
from torchtext.models import T5Transform

padding_idx = 0
eos_idx = 1
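
For context, the transform tied to this import is instantiated in the tutorial's unchanged lines below this hunk; a minimal sketch, assuming the hosted T5 tokenizer model and a 512-token maximum sequence length (illustrative values, not part of this diff):

    transform = T5Transform(
        sp_model_path="https://download.pytorch.org/models/text/t5_tokenizer_base.model",  # assumed hosted tokenizer path
        max_seq_len=512,  # assumed length limit
        eos_idx=eos_idx,
        padding_idx=padding_idx,
    )
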
@@ -66,7 +66,7 @@
#
# ::
#
#   from torchtext.prototype.models import T5_BASE_GENERATION
#   from torchtext.models import T5_BASE_GENERATION
#   transform = T5_BASE_GENERATION.transform()
#

@@ -81,7 +81,7 @@
# https://pytorch.org/text/main/models.html
#
#
from torchtext.prototype.models import T5_BASE_GENERATION
from torchtext.models import T5_BASE_GENERATION


t5_base = T5_BASE_GENERATION
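
A minimal sketch of how the bundle is typically materialized in the unchanged lines that follow (the exact lines are elided by the diff; treat this as an assumption based on the torchtext bundle API):

    model = t5_base.get_model()  # instantiate T5 with pre-trained weights
    model.eval()  # disable dropout for generation
    model.to(DEVICE)  # move to GPU when available
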
@@ -92,146 +92,18 @@


#######################################################################
# Sequence Generator
# GenerationUtils
# ------------------
#
# We can define a sequence generator to produce an output sequence based on the input sequence provided. This calls on the
# We can use torchtext's `GenerationUtils` to produce an output sequence based on the input sequence provided. This calls on the
# model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
# for all sequences in the batch. The `generate` method shown below uses a beam search to generate the sequences. Larger
# beam sizes can result in better generation at the cost of computational complexity, and a beam size of 1 is equivalent to
# a greedy decoder.
#

from torch import Tensor
from torchtext.prototype.models import T5Model


def beam_search(
    beam_size: int,
    step: int,
    bsz: int,
    decoder_output: Tensor,
    decoder_tokens: Tensor,
    scores: Tensor,
    incomplete_sentences: Tensor,
):
    probs = F.log_softmax(decoder_output[:, -1], dim=-1)
    top = torch.topk(probs, beam_size)

    # N is number of sequences in decoder_tokens, L is length of sequences, B is beam_size
    # decoder_tokens has shape (N,L) -> (N,B,L)
    # top.indices has shape (N,B) -> (N,B,1)
    # x has shape (N,B,L+1)
    # note that when step == 1, N = batch_size, and when step > 1, N = batch_size * beam_size
    x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top.indices.unsqueeze(-1)], dim=-1)

    # beams are first created for a given sequence
    if step == 1:
        # x has shape (batch_size, B, L+1) -> (batch_size * B, L+1)
        # new_scores has shape (batch_size, B)
        # incomplete_sentences has shape (batch_size * B) = (N)
        new_decoder_tokens = x.view(-1, step + 1)
        new_scores = top.values
        new_incomplete_sentences = incomplete_sentences

    # beams already exist; expand each beam into possible new tokens to add,
    # and for all expanded beams belonging to the same sequence, choose the top k
    else:
        # scores has shape (batch_size, B) -> (N,1) -> (N,B)
        # top.values has shape (N,B)
        # new_scores has shape (N,B) -> (batch_size, B^2)
        new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(bsz, -1)

        # v, i have shapes (batch_size, B)
        v, i = torch.topk(new_scores, beam_size)

        # x has shape (N,B,L+1) -> (batch_size, B, L+1)
        # i has shape (batch_size, B) -> (batch_size, B, L+1)
        # new_decoder_tokens has shape (batch_size, B, L+1) -> (N, L)
        x = x.view(bsz, -1, step + 1)
        new_decoder_tokens = x.gather(index=i.unsqueeze(-1).repeat(1, 1, step + 1), dim=1).view(-1, step + 1)

        # need to update incomplete sentences in case one of the beams was kicked out
        # y has shape (N) -> (N, 1) -> (N, B) -> (batch_size, B^2)
        y = incomplete_sentences.unsqueeze(-1).repeat(1, beam_size).view(bsz, -1)

        # now we can use i to extract those beams that were selected
        # new_incomplete_sentences has shape (batch_size, B^2) -> (batch_size, B) -> (N, 1) -> (N)
        new_incomplete_sentences = y.gather(index=i, dim=1).view(bsz * beam_size, 1).squeeze(-1)

        # new_scores has shape (batch_size, B)
        new_scores = v

    return new_decoder_tokens, new_scores, new_incomplete_sentences


def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor:
    # pass tokens through encoder
    bsz = encoder_tokens.size(0)
    encoder_padding_mask = encoder_tokens.eq(model.padding_idx)
    encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens))
    encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0]

    encoder_output = model.norm1(encoder_output)
    encoder_output = model.dropout2(encoder_output)

    # initialize decoder input sequence; T5 uses padding index as starter index to decoder sequence
    decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * model.padding_idx
    scores = torch.zeros((bsz, beam_size))

    # mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet
    incomplete_sentences = torch.ones(bsz * beam_size, dtype=torch.long)

    # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
    for step in range(model.config.max_seq_len):
        if step == 1:
            # duplicate and order encoder output so that each beam is treated as its own independent sequence
            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
            new_order = new_order.to(encoder_tokens.device).long()
            encoder_output = encoder_output.index_select(0, new_order)
            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)

        # causal mask and padding mask for decoder sequence
        tgt_len = decoder_tokens.shape[1]
        decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
        decoder_padding_mask = decoder_tokens.eq(model.padding_idx)

        # the T5 implementation uses the padding idx to start the sequence; we want to ignore this when masking
        decoder_padding_mask[:, 0] = False

        # pass decoder sequence through decoder
        decoder_embeddings = model.dropout3(model.token_embeddings(decoder_tokens))
        decoder_output = model.decoder(
            decoder_embeddings,
            memory=encoder_output,
            tgt_mask=decoder_mask,
            tgt_key_padding_mask=decoder_padding_mask,
            memory_key_padding_mask=encoder_padding_mask,
        )[0]

        decoder_output = model.norm2(decoder_output)
        decoder_output = model.dropout4(decoder_output)
        decoder_output = decoder_output * (model.config.embedding_dim ** -0.5)
        decoder_output = model.lm_head(decoder_output)

        decoder_tokens, scores, incomplete_sentences = beam_search(
            beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences
        )
        # ignore newest tokens for sentences that are already complete
        decoder_tokens[:, -1] *= incomplete_sentences

        # update incomplete_sentences to remove those that were just ended
        incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long()

        # early stop if all sentences have been ended
        if (incomplete_sentences == 0).all():
            break

    # take most likely sequence
    decoder_tokens = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :]
    return decoder_tokens
# for all sequences in the batch. The `generate` method shown below uses greedy search to generate the sequences. Beam search and
# other decoding strategies are also supported.
#
#
from torchtext.prototype.generate import GenerationUtils

sequence_generator = GenerationUtils(model)
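
A minimal usage sketch for the new generator, mirroring the call sites later in this diff (the prompt string is illustrative):

    tokens = transform(["summarize: studies have shown that owning a dog is good for you"])
    output = sequence_generator.generate(tokens, eos_idx=eos_idx, beam_size=1)  # beam_size=1 is greedy search
    print(transform.decode(output.tolist()))
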


#######################################################################
@@ -343,16 +215,16 @@ def process_labels(labels, x):
# ------------------
#
# We can put all of the components together to generate summaries on the first batch of articles in the CNNDM test set
# using a beam size of 3.
# using a beam size of 1.
#

batch = next(iter(cnndm_dataloader))
input_text = batch["article"]
target = batch["abstract"]
beam_size = 3
beam_size = 1

model_input = transform(input_text)
model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(cnndm_batch_size):
@@ -442,7 +314,7 @@ def process_labels(labels, x):
beam_size = 1

model_input = transform(input_text)
model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(imdb_batch_size):
@@ -536,7 +408,7 @@ def process_labels(labels, x):
beam_size = 4

model_input = transform(input_text)
model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(multi_batch_size):
4 changes: 2 additions & 2 deletions packaging/install_torchdata.sh
@@ -1,11 +1,11 @@
#!/bin/bash
package_type="$PACKAGE_TYPE"
channel="$CHANNEl"
channel="test"
if [ -z "$package_type" ]; then
    package_type="wheel"
fi
if [ -z "$channel" ]; then
    channel="nightly"
    channel="test"
fi

# Wrong values
10 changes: 5 additions & 5 deletions packaging/pkg_helpers.bash
@@ -181,7 +181,7 @@ setup_pip_pytorch_version() {
    if [[ -z "$PYTORCH_VERSION" ]]; then
        # Install latest prerelease version of torch, per our nightlies, consistent
        # with the requested cuda version
        pip_install --pre torch -f "https://download.pytorch.org/whl/nightly/${WHEEL_DIR}torch_nightly.html"
        pip_install --pre torch -f "https://download.pytorch.org/whl/test/${WHEEL_DIR}torch_test.html"
        # CUDA and CPU are ABI compatible on the CPU-only parts, so strip
        # in this case
        export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//' | sed 's/+.\+//')"
@@ -191,7 +191,7 @@ setup_pip_pytorch_version() {
            -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/torch_${UPLOAD_CHANNEL}.html"
    fi
    if [[ -z "$TORCHDATA_VERSION" ]]; then
        pip_install --pre torchdata -f "https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html"
        pip_install --pre torchdata -f "https://download.pytorch.org/whl/test/cpu/torch_test.html"
        export TORCHDATA_VERSION="$(pip show torchdata | grep ^Version: | sed 's/Version: *//' | sed 's/+.\+//')"
    else
        pip_install "torchdata==$TORCHDATA_VERSION" \
@@ -207,13 +207,13 @@ setup_pip_pytorch_version() {
setup_conda_pytorch_constraint() {
    CONDA_CHANNEL_FLAGS=${CONDA_CHANNEL_FLAGS:-}
    if [[ -z "$PYTORCH_VERSION" ]]; then
        export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly"
        export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-${UPLOAD_CHANNEL}"
        PYTHON="python"
        # Check if we have python 3 instead and prefer that
        if python3 --version >/dev/null 2>/dev/null; then
            PYTHON="python3"
        fi
        export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | ${PYTHON} -c "import sys, json, re; print(re.sub(r'\\+.*$', '', json.load(sys.stdin)['pytorch'][-1]['version']))")"
        export PYTORCH_VERSION="$(conda search --json "pytorch[channel=pytorch-${UPLOAD_CHANNEL}]" | ${PYTHON} -c "import sys, json, re; print(re.sub(r'\\+.*$', '', json.load(sys.stdin)['pytorch'][-1]['version']))")"
    else
        export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch -c pytorch-${UPLOAD_CHANNEL}"
    fi
@@ -233,7 +233,7 @@ setup_conda_pytorch_constraint() {
        fi
    fi
    if [[ -z "$TORCHDATA_VERSION" ]]; then
        export TORCHDATA_VERSION="$(conda search --json 'torchdata[channel=pytorch-nightly]' | ${PYTHON} -c "import sys, json, re; print(re.sub(r'\\+.*$', '', json.load(sys.stdin)['torchdata'][-1]['version']))")"
        export TORCHDATA_VERSION="$(conda search --json 'torchdata[channel=pytorch-test]' | ${PYTHON} -c "import sys, json, re; print(re.sub(r'\\+.*$', '', json.load(sys.stdin)['torchdata'][-1]['version']))")"
    fi
    export CONDA_TORCHDATA_CONSTRAINT="- torchdata==$TORCHDATA_VERSION"
}
2 changes: 1 addition & 1 deletion test/smoke_tests/smoke_tests.py
@@ -19,5 +19,5 @@ def validateTorchdataVersion():
        raise RuntimeError(f"torchdata binary {torchdata.__version__} is more than {NIGHTLY_ALLOWED_DELTA} days old!")


validateTorchdataVersion()
# validateTorchdataVersion()
print("torchtext version is ", torchtext.__version__)