Updating T5 demo to use beam search for generator #1869
Conversation
from torch import Tensor
from torchtext.prototype.models import T5Model
One thing to note is that when generating the first tokens of the sequences, decoder_tokens has shape (batch_size, 1). Since we are using a beam search, at that first iteration each sequence has k candidate tokens with which to start (since we choose the top k tokens to expand a given sequence). Because the decoder expects decoder_tokens to be 2D, with one sequence per row, we treat each beam as its own sequence, so decoder_tokens now has shape (batch_size * beam_size, 2), where the first k rows are the beams belonging to the original sequence 1, the next k rows are the beams belonging to the original sequence 2, etc.
Since the decoder also requires the encoder outputs as an input argument, we must likewise reshape the encoder output, since it only holds the output for batch_size sequences. Lines 239-244 define a new order in which each encoder output is repeated k times. This means that along dim=0, the first k indices contain the encoder output for the original sequence 1, the next k for the original sequence 2, etc. That way, as we pass in each beam as its own sequence in decoder_tokens, the decoder has the correct corresponding encoder output for that sequence.
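As a minimal sketch of that reshaping (tensor names and sizes here are illustrative, not the tutorial's exact code), torch.repeat_interleave produces exactly this ordering:

    import torch

    batch_size, beam_size, src_len, d_model = 2, 3, 8, 512  # illustrative sizes

    # Encoder output for the original batch: one row of hidden states per input sequence.
    encoder_output = torch.randn(batch_size, src_len, d_model)

    # After the first expansion each original sequence owns beam_size beams,
    # flattened into the batch dimension: (batch_size * beam_size, 2).
    decoder_tokens = torch.zeros(batch_size * beam_size, 2, dtype=torch.long)

    # Repeat each encoder output beam_size times along dim=0 so that each row of
    # decoder_tokens lines up with the encoder output of its original sequence:
    # rows 0..k-1 -> sequence 1, rows k..2k-1 -> sequence 2, etc.
    encoder_output = encoder_output.repeat_interleave(beam_size, dim=0)

    assert encoder_output.shape[0] == decoder_tokens.shape[0]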
) -> Tensor:

def beam_search(
    beam_size: int,
    step: int,
step here is equivalent to the current length of the sequences in decoder_tokens. The first time beam_search is called, step=1 because decoder_tokens is initialized with the padding token as the start token of each sequence.
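For illustration, the initialization implied here could look like the following (the padding index and tensor names are assumptions, not the tutorial's exact code):

    import torch

    batch_size = 4
    padding_idx = 0  # assumed id of the padding token used as the start token

    # Every sequence starts as just the padding token, so the current sequence
    # length -- and therefore step on the first call to beam_search -- is 1.
    decoder_tokens = torch.full((batch_size, 1), padding_idx, dtype=torch.long)
    step = decoder_tokens.shape[1]  # == 1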
parmeet
left a comment
LGTM!
Description
Update T5 tutorial to use a beam search for decoding, as opposed to a greedy search.
Process
A beam search was implemented, which keeps track of the log probability of multiple candidate sequences generated for a single input sequence and prunes them at each iteration, keeping only the top k most likely. This allows the generator to create sequences that are more probable than those generated by a greedy search. We find that the sequences generated via beam search tend to be longer, but still mostly capture the main points expressed in the target summaries. We expect further improvement in the summaries with the addition of constraints such as a length penalty, an n-gram repetition limit, and minimum/maximum lengths.
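To make the pruning step concrete, below is a minimal sketch of one beam-search step over per-token log probabilities; the function and variable names are illustrative, not the tutorial's actual implementation:

    import torch

    def beam_search_step(log_probs, beam_scores, beam_size):
        """One pruning step: expand every beam by the full vocabulary, then
        keep only the top-k (beam, token) continuations per input sequence.

        log_probs:   (batch * beam, vocab) log probabilities for the next token
        beam_scores: (batch * beam,) cumulative log probability of each beam
        """
        num_beams, vocab = log_probs.shape
        batch = num_beams // beam_size

        # Cumulative score of every candidate continuation of every beam.
        scores = beam_scores.unsqueeze(-1) + log_probs        # (batch * beam, vocab)
        scores = scores.view(batch, beam_size * vocab)        # all candidates per input

        # Prune: keep only the beam_size most likely continuations per input.
        top_scores, top_idx = scores.topk(beam_size, dim=-1)  # (batch, beam_size)
        beam_idx = top_idx // vocab    # which beam each kept candidate extends
        token_idx = top_idx % vocab    # which token it appends to that beam

        return top_scores, beam_idx, token_idx

A full implementation also treats the first step specially (all beams of a sequence start identical, so only one should be expanded), and this is where constraints such as a length penalty would be applied.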
Testing
We tested the logic by ascertaining that the sequences generated when beam_size=1 were the same as those generated by a greedy decoder --> see the Generate Summaries section of this notebook.
Run BUILD_GALLERY=1 make 'SPHINXOPTS=-W' html in docs and review the rendered document in docs/build/html/tutorials/cnndm_summarization.html.
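As a toy illustration of why that check holds, pruning with beam_size=1 keeps exactly the argmax token at each step; reusing the hypothetical beam_search_step sketch from the Process section above:

    import torch

    log_probs = torch.log_softmax(torch.randn(1, 10), dim=-1)  # one beam, vocab of 10
    beam_scores = torch.zeros(1)
    top_scores, beam_idx, token_idx = beam_search_step(log_probs, beam_scores, beam_size=1)
    assert token_idx.item() == log_probs.argmax().item()  # same token a greedy decoder picks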