
Conversation

@pmabbo13 (Contributor) commented Aug 2, 2022

Description

Update T5 tutorial to use a beam search for decoding, as opposed to a greedy search.

Process

A beam search was implemented that keeps track of the log probabilities of multiple candidate sequences generated for a single input sequence, pruning them at each iteration to keep only the top k most likely. This allows the generator to produce sequences that are more probable overall than those generated by a greedy search.
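A minimal sketch of that pruning step, using illustrative tensor names (`candidate_scores`, `beam_idx`, etc.) rather than the tutorial's actual variables:

```python
import torch

# Toy shapes for illustration only
batch_size, beam_size, vocab_size = 2, 3, 10

# Accumulated log probability of extending each current beam by each vocabulary token
candidate_scores = torch.randn(batch_size, beam_size, vocab_size)

# Rank all beam/token combinations jointly per input and keep the beam_size best
flat_scores = candidate_scores.view(batch_size, beam_size * vocab_size)
top_scores, top_idx = flat_scores.topk(beam_size, dim=-1)

# Recover which beam each survivor extends and which token it appends
beam_idx = top_idx // vocab_size   # (batch_size, beam_size)
token_idx = top_idx % vocab_size   # (batch_size, beam_size)
```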

We find that the sequences generated via beam search tend to be longer, but still mostly capture the main points expressed in the target summaries. We expect an improvement in the summaries with the addition of constraints such as a length penalty, ngram limit, min length, max length, etc.

Testing

We tested the logic by verifying that the sequences generated with beam_size=1 match those generated by the greedy decoder (see the Generate Summaries section of this notebook).

Run `BUILD_GALLERY=1 make 'SPHINXOPTS=-W' html` in `docs` and review the rendered document at `docs/build/html/tutorials/cnndm_summarization.html`.

@pmabbo13 pmabbo13 changed the title updating demo to use beam search for generator Updating T5 demo to use beam search for generator Aug 2, 2022
from torch import Tensor
from torchtext.prototype.models import T5Model


@pmabbo13 (Contributor Author) Aug 3, 2022


One thing to note is that when generating the first tokens of the sequences, decoder_tokens has shape (batch_size, 1). Since we are using beam search, at that first iteration each input sequence gets k candidate starting tokens (we choose the top k tokens with which to expand each sequence). Because the decoder expects decoder_tokens to be 2D with one sequence per row, we treat each beam as its own sequence, so decoder_tokens now has shape (batch_size * beam_size, 2), where the first k rows are the beams belonging to original sequence 1, the next k rows are the beams belonging to original sequence 2, and so on.
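A rough sketch of that first-step expansion, assuming a padding id of 0 and illustrative names (`first_tokens` is not the tutorial's actual variable):

```python
import torch

batch_size, beam_size, vocab_size, padding_idx = 2, 3, 10, 0

# decoder_tokens starts as one padding token per input sequence: shape (batch_size, 1)
decoder_tokens = torch.full((batch_size, 1), padding_idx, dtype=torch.long)

# Suppose these are the top-k starting tokens chosen for each input: shape (batch_size, beam_size)
first_tokens = torch.randint(vocab_size, (batch_size, beam_size))

# Repeat each starter row beam_size times and append the chosen tokens so that each
# beam becomes its own row: shape (batch_size * beam_size, 2)
decoder_tokens = torch.cat(
    [decoder_tokens.repeat_interleave(beam_size, dim=0), first_tokens.reshape(-1, 1)],
    dim=1,
)
print(decoder_tokens.shape)  # torch.Size([6, 2])
```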

Since the decoder also requires the encoder outputs as an input argument, we must now reshape the encoder output as well, since it only contains outputs for batch_size sequences. Lines 239-244 define a new order in which the encoder output for each original sequence is repeated k times. This means that along dim=0, the first k indices contain the encoder output for original sequence 1, the next k for original sequence 2, and so on. This ensures that, as we pass each beam in as its own sequence in decoder_tokens, the decoder has the correct corresponding encoder output for that sequence.
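A minimal illustration of that reordering via `repeat_interleave` (the actual lines 239-244 may be written differently):

```python
import torch

batch_size, beam_size, src_len, d_model = 2, 3, 5, 8

# One encoder output per original input sequence
encoder_output = torch.randn(batch_size, src_len, d_model)

# Repeat each input's encoder output beam_size times along dim=0 so that row i of
# decoder_tokens (one beam per row) is paired with the encoder output of the input
# it came from: rows 0..k-1 -> input 1, rows k..2k-1 -> input 2, and so on
encoder_output = encoder_output.repeat_interleave(beam_size, dim=0)
print(encoder_output.shape)  # torch.Size([6, 5, 8])
```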

) -> Tensor:
def beam_search(
beam_size: int,
step: int,
Contributor Author


step here is equivalent to the current length of the sequences in decoder_tokens. The first time beam_search is called, step=1 because decoder_tokens is initialized with the padding token as the starter token for each sequence.
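For example, assuming a padding id of 0, the starting state would look like this, which is why step begins at 1:

```python
import torch

batch_size, padding_idx = 2, 0

# Each sequence is initialized as a single padding (starter) token
decoder_tokens = torch.full((batch_size, 1), padding_idx, dtype=torch.long)
step = decoder_tokens.shape[-1]  # current sequence length = 1
```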

@parmeet (Contributor) left a comment


LGTM!

@pmabbo13 pmabbo13 merged commit 2bb2562 into pytorch:main Aug 3, 2022
@pmabbo13 pmabbo13 deleted the feature/t5-beam-search branch August 3, 2022 21:10
