Updating T5 demo to use beam search for generator #1869
Merged
@@ -149,21 +149,78 @@ def apply_prefix(task, x):
 #
 # We can define a sequence generator to produce an output sequence based on the input sequence provided. This calls on the
 # model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
-# for all sequences in the batch. The `greedy_generator` method shown below uses a greedy search (i.e. expands the sequence
-# based on the most probable next word).
+# for all sequences in the batch. The `generate` method shown below uses a beam search to generate the sequences. Larger
+# beam sizes can result in better generation at the cost of computational complexity, and a beam size of 1 is equivalent to
+# a greedy decoder.
 #

 from torch import Tensor
 from torchtext.prototype.models import T5Model


-def greedy_generator(
-    encoder_tokens: Tensor,
-    eos_idx: int,
-    model: T5Model,
-) -> Tensor:
+def beam_search(
+    beam_size: int,
+    step: int,
+    bsz: int,
+    decoder_output: Tensor,
+    decoder_tokens: Tensor,
+    scores: Tensor,
+    incomplete_sentences: Tensor,
+):
+    probs = F.log_softmax(decoder_output[:, -1], dim=-1)
+    top = torch.topk(probs, beam_size)
+
+    # N is number of sequences in decoder_tokens, L is length of sequences, B is beam_size
+    # decoder_tokens has shape (N,L) -> (N,B,L)
+    # top.indices has shape (N,B) -> (N,B,1)
+    # x has shape (N,B,L+1)
+    # note that when step == 1, N = batch_size, and when step > 1, N = batch_size * beam_size
+    x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top.indices.unsqueeze(-1)], dim=-1)
+
+    # beams are first created for a given sequence
+    if step == 1:
+        # x has shape (batch_size, B, L+1) -> (batch_size * B, L+1)
+        # new_scores has shape (batch_size,B)
+        # incomplete_sentences has shape (batch_size * B) = (N)
+        new_decoder_tokens = x.view(-1, step + 1)
+        new_scores = top.values
+        new_incomplete_sentences = incomplete_sentences
+
+    # beams already exist; we want to expand each beam with the possible new tokens to add,
+    # and for all expanded beams belonging to the same sequence, choose the top k
+    else:
+        # scores has shape (batch_size,B) -> (N,1) -> (N,B)
+        # top.values has shape (N,B)
+        # new_scores has shape (N,B) -> (batch_size, B^2)
+        new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(bsz, -1)
+
+        # v, i have shapes (batch_size, B)
+        v, i = torch.topk(new_scores, beam_size)
+
+        # x has shape (N,B,L+1) -> (batch_size, B, L+1)
+        # i has shape (batch_size, B) -> (batch_size, B, L+1)
+        # new_decoder_tokens has shape (batch_size, B, L+1) -> (N, L)
+        x = x.view(bsz, -1, step + 1)
+        new_decoder_tokens = x.gather(index=i.unsqueeze(-1).repeat(1, 1, step + 1), dim=1).view(-1, step + 1)
+
+        # need to update incomplete sentences in case one of the beams was kicked out
+        # y has shape (N) -> (N, 1) -> (N, B) -> (batch_size, B^2)
+        y = incomplete_sentences.unsqueeze(-1).repeat(1, beam_size).view(bsz, -1)
+
+        # now we can use i to extract those beams that were selected
+        # new_incomplete_sentences has shape (batch_size, B^2) -> (batch_size, B) -> (N, 1) -> N
+        new_incomplete_sentences = y.gather(index=i, dim=1).view(bsz * beam_size, 1).squeeze(-1)
+
+        # new_scores has shape (batch_size, B)
+        new_scores = v
+
+    return new_decoder_tokens, new_scores, new_incomplete_sentences
+
+
+def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor:

     # pass tokens through encoder
+    bsz = encoder_tokens.size(0)
     encoder_padding_mask = encoder_tokens.eq(model.padding_idx)
     encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens))
     encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0]

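To make the shape bookkeeping in `beam_search` concrete, here is a minimal standalone sketch of the expansion step. All sizes and the random logits below are illustrative stand-ins, not values from the tutorial:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only: N sequences, beam size B, current length L, tiny vocab.
N, B, L, vocab = 2, 3, 1, 5
decoder_tokens = torch.zeros((N, L), dtype=torch.long)  # running sequences, shape (N, L)
decoder_output = torch.randn(N, L, vocab)               # random stand-in for the LM head output

probs = F.log_softmax(decoder_output[:, -1], dim=-1)    # (N, vocab)
top = torch.topk(probs, B)                              # .values and .indices both have shape (N, B)

# Copy every sequence B times and append one candidate token to each copy:
# (N, L) -> (N, B, L), concatenated with (N, B, 1) -> (N, B, L+1)
x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, B, 1), top.indices.unsqueeze(-1)], dim=-1)
print(x.shape)  # torch.Size([2, 3, 2])
```

Each row of `x` now holds one candidate continuation per beam; the `step == 1` and `step > 1` branches above differ only in how the top candidates are selected and flattened back into a 2D batch.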
@@ -172,14 +229,22 @@ def greedy_generator(
     encoder_output = model.dropout2(encoder_output)

     # initialize decoder input sequence; T5 uses padding index as starter index to decoder sequence
-    decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * model.padding_idx
+    decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * model.padding_idx
+    scores = torch.zeros((bsz, beam_size))

     # mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet
-    incomplete_sentences = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long)
+    incomplete_sentences = torch.ones(bsz * beam_size, dtype=torch.long)

     # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
     for step in range(model.config.max_seq_len):

+        if step == 1:
+            # duplicate and order encoder output so that each beam is treated as its own independent sequence
+            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+            new_order = new_order.to(encoder_tokens.device).long()
+            encoder_output = encoder_output.index_select(0, new_order)
+            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)

         # causal mask and padding mask for decoder sequence
         tgt_len = decoder_tokens.shape[1]
         decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()

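The causal mask built here with `torch.triu` is standard PyTorch; a quick sketch of what it looks like for a hypothetical target length of 4 (`True` marks future positions the decoder may not attend to):

```python
import torch

tgt_len = 4  # hypothetical decoder sequence length
decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
print(decoder_mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```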
@@ -203,39 +268,39 @@ def greedy_generator(
         decoder_output = decoder_output * (model.config.embedding_dim ** -0.5)
         decoder_output = model.lm_head(decoder_output)

-        # greedy search for next token to add to sequence
-        probs = F.log_softmax(decoder_output[:, -1], dim=-1)
-        _, next_token = torch.topk(probs, 1)
-
-        # ignore next tokens for sentences that are already complete
-        next_token *= incomplete_sentences
+        decoder_tokens, scores, incomplete_sentences = beam_search(
+            beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences
+        )
+        # ignore newest tokens for sentences that are already complete
+        decoder_tokens[:, -1] *= incomplete_sentences

         # update incomplete_sentences to remove those that were just ended
-        incomplete_sentences = incomplete_sentences - (next_token == eos_idx).long()
-
-        # update decoder sequences to include new tokens
-        decoder_tokens = torch.cat((decoder_tokens, next_token), 1)
+        incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long()

         # early stop if all sentences have been ended
         if (incomplete_sentences == 0).all():
             break

+    # take most likely sequence
+    decoder_tokens = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :]
     return decoder_tokens


 #######################################################################
 # Generate Summaries
 # ------------------
 #
-# Finally we put all of the components together to generate summaries on the first batch of articles in the CNNDM test set.
+# Finally we put all of the components together to generate summaries on the first batch of articles in the CNNDM test set
+# using a beam size of 3.
 #

 batch = next(iter(test_dataloader))
 input_text = batch["article"]
 model_input = transform(input_text)
 target = batch["abstract"]
+beam_size = 3

-model_output = greedy_generator(model=model, encoder_tokens=model_input, eos_idx=eos_idx)
+model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
 output_text = transform.decode(model_output.tolist())

 for i in range(batch_size):

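The last step of `generate` keeps only the highest-scoring beam per input. A toy sketch of that selection, with a hypothetical batch of 2, beam size 3, and made-up token ids (`beam_search` keeps each group of rows sorted best-first because `torch.topk` returns sorted values):

```python
import torch

bsz, beam_size = 2, 3
# Hypothetical flattened beams: rows 0-2 belong to input 0, rows 3-5 to input 1.
decoder_tokens = torch.tensor([
    [0, 11, 12], [0, 21, 22], [0, 31, 32],   # beams for input 0, best first
    [0, 41, 42], [0, 51, 52], [0, 61, 62],   # beams for input 1, best first
])
best = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :]
print(best)  # tensor([[ 0, 11, 12],
             #         [ 0, 41, 42]])
```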
@@ -253,10 +318,10 @@ def greedy_generator(
 #
 # Example 1:
 #
-# prediction: the Palestinians officially become the 123rd member of the international
-# criminal court . the move gives the court jurisdiction over alleged crimes committed
-# in the occupied Palestinian territory . the ICC opened a preliminary examination into
-# the situation in the occupied territories .
+# prediction: the Palestinians become the 123rd member of the international criminal
+# court . the accession was marked by a ceremony at the Hague, where the court is based .
+# the ICC opened a preliminary examination into the situation in the occupied
+# Palestinian territory .
 #
 # target: Membership gives the ICC jurisdiction over alleged crimes committed in
 # Palestinian territories since last June . Israel and the United States opposed the

@@ -265,10 +330,10 @@ def greedy_generator(
 #
 # Example 2:
 #
-# prediction: a stray pooch in Washington state has used up at least three of her own
-# after being hit by a car . the dog staggers to a nearby farm, dirt-covered and
-# emaciated, where she is found . she suffered a dislocated jaw, leg injuries and a
-# caved-in sinus cavity .
+# prediction: a stray pooch has used up at least three of her own after being hit by a
+# car and buried in a field . the dog managed to stagger to a nearby farm, dirt-covered
+# and emaciated, where she was found . she suffered a dislocated jaw, leg injuries and a
+# caved-in sinus cavity -- and still requires surgery to help her breathe .
 #
 # target: Theia, a bully breed mix, was apparently hit by a car, whacked with a hammer
 # and buried in a field . "She's a true miracle dog and she deserves a good life," says

@@ -277,9 +342,9 @@ def greedy_generator(
 #
 # Example 3:
 #
-# prediction: mohammad Javad Zarif is the foreign minister of the country . he has been
-# a key figure in securing a breakthrough in nuclear talks . he has been a hero in the
-# international community .
+# prediction: mohammad Javad Zarif arrived in Iran on a sunny friday morning . he has gone
+# a long way to bring Iran in from the cold and allow it to rejoin the international
+# community . but there are some facts about him that are less well-known .
 #
 # target: Mohammad Javad Zarif has spent more time with John Kerry than any other
 # foreign minister . He once participated in a takeover of the Iranian Consulate in San

@@ -288,9 +353,9 @@ def greedy_generator(
 #
 # Example 4:
 #
-# prediction: five americans were monitored for three weeks after being exposed to
-# Ebola . one of the five had a heart-related issue on Saturday and has been discharged .
-# none of the patients developed the deadly virus .
+# prediction: five americans were monitored for three weeks after being exposed to Ebola in
+# west africa . one of the five had a heart-related issue and has been discharged but hasn't
+# left the area . they are clinicians for Partners in Health, a Boston-based aid group .
 #
 # target: 17 Americans were exposed to the Ebola virus while in Sierra Leone in March .
 # Another person was diagnosed with the disease and taken to hospital in Maryland .

@@ -302,7 +367,8 @@ def greedy_generator(
 #
 # prediction: the student was identified during an investigation by campus police and
 # the office of student affairs . he admitted to placing the noose on the tree early
-# Wednesday morning .
+# Wednesday morning . the incident is one of several recent racist events to affect
+# college students .
 #
 # target: Student is no longer on Duke University campus and will face disciplinary
 # review . School officials identified student during investigation and the person

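The qualitative differences between the greedy and beam outputs above stem from beam search being able to recover a sequence whose first token is not the locally best choice. A toy illustration with hand-picked, purely hypothetical log-probabilities:

```python
import math

# Hypothetical two-step vocabulary of {A, B}; probabilities chosen by hand.
step1 = {"A": math.log(0.6), "B": math.log(0.4)}
step2 = {"A": {"A": math.log(0.5), "B": math.log(0.5)},
         "B": {"A": math.log(0.9), "B": math.log(0.1)}}

# Greedy commits to A after step 1 (0.6 > 0.4) and ends at log(0.6 * 0.5).
greedy = step1["A"] + max(step2["A"].values())
# An exhaustive (or beam size 2) search finds B -> A at log(0.4 * 0.9).
best = max(step1[t1] + step2[t1][t2] for t1 in step1 for t2 in step2[t1])
print(math.exp(greedy), math.exp(best))  # ~0.30 vs ~0.36
```

A greedy decoder is stuck with the 0.30 sequence, while a beam of size 2 keeps `B` alive after the first step and finds the higher-probability 0.36 sequence.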
Inline comment from the PR author:

One thing to note is that when generating the first tokens of the sequences, `decoder_tokens` has shape `(batch_size, 1)`. Since we are using a beam search, at that first iteration each sequence has `k` many tokens with which to start the sequence (since we are choosing the top `k` tokens to expand a given sequence by). Since the decoder expects `decoder_tokens` to be 2D, with one sequence per row, we treat each beam as its own sequence, so that `decoder_tokens` now has shape `(batch_size * beam_size, 2)`, where the first `k` rows are the beams belonging to the original sequence 1, the next `k` rows are the beams belonging to the original sequence 2, etc.

Since the decoder also requires the encoder outputs as an input argument, we must now reshape the encoder output as well, since it only holds the output for a `batch_size` number of sequences. Lines 239-244 define a new order in which we repeat the encoder output for each given original sequence `k` times. This means that along `dim=0`, the first `k` indices contain the encoder output for the original sequence 1, the next `k` for the original sequence 2, etc. This ensures that, as we pass in each beam as its own sequence in `decoder_tokens`, the decoder has the correct corresponding encoder output for that sequence.
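A minimal sketch of the reordering described in this comment, using a hypothetical batch size of 2 and beam size of 3 with random stand-in encoder output:

```python
import torch

bsz, beam_size, seq_len, d_model = 2, 3, 4, 8  # hypothetical sizes
encoder_output = torch.randn(bsz, seq_len, d_model)

# Repeat each batch index beam_size times: [0, 0, 0, 1, 1, 1]
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
print(new_order)  # tensor([0, 0, 0, 1, 1, 1])

# Each beam now sees the encoder output of its original sequence.
encoder_output = encoder_output.index_select(0, new_order)
print(encoder_output.shape)  # torch.Size([6, 4, 8])
```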