138 changes: 102 additions & 36 deletions examples/tutorials/cnndm_summarization.py
@@ -149,21 +149,78 @@ def apply_prefix(task, x):
#
# We can define a sequence generator to produce an output sequence based on the input sequence provided. This calls on the
# model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
-# for all sequences in the batch. The `greedy_generator` method shown below uses a greedy search (i.e. expands the sequence
-# based on the most probable next word).
# for all sequences in the batch. The `generate` method shown below uses a beam search to generate the sequences. Larger
# beam sizes can result in better generation at the cost of computational complexity, and a beam size of 1 is equivalent to
# a greedy decoder.
#
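#
# As a quick sketch of the expansion step used below (illustrative only; the sizes
# here are assumed toy values): each step scores the vocabulary with a log-softmax
# and keeps the top-k candidates per sequence, and with k=1 this reduces to a
# greedy search.

import torch
import torch.nn.functional as F

toy_log_probs = F.log_softmax(torch.randn(2, 10), dim=-1)  # (num_sequences, vocab_size)
toy_top = torch.topk(toy_log_probs, k=3)  # top-3 candidate tokens and their scores per sequence
assert toy_top.indices.shape == (2, 3)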

from torch import Tensor
from torchtext.prototype.models import T5Model


@pmabbo13 (Contributor Author) commented on Aug 3, 2022:


One thing to note is that when generating the first tokens of the sequences, decoder_tokens has shape (batch_size, 1). Since we are using a beam search, at that first iteration each sequence has k candidate tokens with which to start (we choose the top k tokens to expand each sequence by). Since the decoder expects decoder_tokens to be 2D, with one sequence per row, we treat each beam as its own sequence, so that decoder_tokens now has shape (batch_size * beam_size, 2), where the first k rows are the beams belonging to original sequence 1, the next k rows are the beams belonging to original sequence 2, and so on.
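
A small shape sketch of that first expansion (toy sizes assumed: batch_size=2, beam_size=3; top_indices stands in for the real topk output):

import torch

batch_size, beam_size = 2, 3
decoder_tokens = torch.zeros(batch_size, 1, dtype=torch.long)  # one start token per sequence
top_indices = torch.arange(batch_size * beam_size).view(batch_size, beam_size)

# repeat each start token across its k beams, then append one candidate token per beam
x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top_indices.unsqueeze(-1)], dim=-1)
x = x.view(-1, 2)  # (batch_size * beam_size, 2): rows 0-2 are the beams of sequence 1, rows 3-5 of sequence 2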

Since the decoder also requires the encoder outputs as an input argument, we must also reshape the encoder output, since it only holds the output for batch_size sequences. Lines 239-244 define a new order in which we repeat the encoder output for each original sequence k times. This means that along dim=0, the first k indices contain the encoder output for original sequence 1, the next k for original sequence 2, and so on. This way, as we pass in each beam as its own sequence in decoder_tokens, the decoder has the correct corresponding encoder output for that sequence.
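
For example (a sketch with assumed toy sizes), the reordering index looks like this:

import torch

bsz, beam_size = 2, 3
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
# new_order is tensor([0, 0, 0, 1, 1, 1]): each sequence's encoder output is repeated k times

encoder_output = torch.randn(bsz, 5, 8)  # (batch_size, src_len, embed_dim), toy sizes
encoder_output = encoder_output.index_select(0, new_order)  # (batch_size * beam_size, src_len, embed_dim)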

-def greedy_generator(
-    encoder_tokens: Tensor,
-    eos_idx: int,
-    model: T5Model,
-) -> Tensor:
def beam_search(
    beam_size: int,
    step: int,
@pmabbo13 (Contributor Author) commented:

step here is equivalent to the current length of the sequences in decoder_tokens. The first time beam_search is called, step=1 because decoder_tokens is initialized with the padding token as the start token for each sequence.
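
A sketch of the first call's input (padding_idx=0 is assumed here; the real value comes from the model):

import torch
bsz, padding_idx = 2, 0
decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * padding_idx  # shape (batch_size, 1), so step == 1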

    bsz: int,
    decoder_output: Tensor,
    decoder_tokens: Tensor,
    scores: Tensor,
    incomplete_sentences: Tensor,
):
    probs = F.log_softmax(decoder_output[:, -1], dim=-1)
    top = torch.topk(probs, beam_size)

    # N is number of sequences in decoder_tokens, L is length of sequences, B is beam_size
    # decoder_tokens has shape (N,L) -> (N,B,L)
    # top.indices has shape (N,B) -> (N,B,1)
    # x has shape (N,B,L+1)
    # note that when step == 1, N = batch_size, and when step > 1, N = batch_size * beam_size
    x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top.indices.unsqueeze(-1)], dim=-1)

    # beams are first created for a given sequence
    if step == 1:
        # x has shape (batch_size, B, L+1) -> (batch_size * B, L+1)
        # new_scores has shape (batch_size,B)
        # incomplete_sentences has shape (batch_size * B) = (N)
        new_decoder_tokens = x.view(-1, step + 1)
        new_scores = top.values
        new_incomplete_sentences = incomplete_sentences

    # beams already exist, want to expand each beam into possible new tokens to add
    # and for all expanded beams belonging to the same sequences, choose the top k
    else:
        # scores has shape (batch_size,B) -> (N,1) -> (N,B)
        # top.values has shape (N,B)
        # new_scores has shape (N,B) -> (batch_size, B^2)
        new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(bsz, -1)

        # v, i have shapes (batch_size, B)
        v, i = torch.topk(new_scores, beam_size)

        # x has shape (N,B,L+1) -> (batch_size, B, L+1)
        # i has shape (batch_size, B) -> (batch_size, B, L+1)
        # new_decoder_tokens has shape (batch_size, B, L+1) -> (N, L)
        x = x.view(bsz, -1, step + 1)
        new_decoder_tokens = x.gather(index=i.unsqueeze(-1).repeat(1, 1, step + 1), dim=1).view(-1, step + 1)

        # need to update incomplete sentences in case one of the beams was kicked out
        # y has shape (N) -> (N, 1) -> (N, B) -> (batch_size, B^2)
        y = incomplete_sentences.unsqueeze(-1).repeat(1, beam_size).view(bsz, -1)

        # now can use i to extract those beams that were selected
        # new_incomplete_sentences has shape (batch_size, B^2) -> (batch_size, B) -> (N, 1) -> N
        new_incomplete_sentences = y.gather(index=i, dim=1).view(bsz * beam_size, 1).squeeze(-1)

        # new_scores has shape (batch_size, B)
        new_scores = v

    return new_decoder_tokens, new_scores, new_incomplete_sentences
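
# A toy walk-through of the else branch above (illustrative only; assumed sizes
# batch_size=1, B=2): each of the 2 existing beams proposes 2 candidates, giving
# B^2 = 4 scored continuations per original sequence, from which topk keeps the best 2.

toy_scores = torch.tensor([[-0.1, -0.9]])  # (batch_size, B): running score of each beam
toy_expanded = (toy_scores.view(-1, 1).repeat(1, 2) + torch.tensor([[-0.2, -0.3], [-0.1, -0.4]])).view(1, -1)
toy_v, toy_i = torch.topk(toy_expanded, 2)  # keeps -0.3 and -0.4 out of the 4 candidate scores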


def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor:

    # pass tokens through encoder
    bsz = encoder_tokens.size(0)
    encoder_padding_mask = encoder_tokens.eq(model.padding_idx)
    encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens))
    encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0]
@@ -172,14 +229,22 @@ def greedy_generator(
    encoder_output = model.dropout2(encoder_output)

    # initialize decoder input sequence; T5 uses padding index as starter index to decoder sequence
-    decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * model.padding_idx
    decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * model.padding_idx
    scores = torch.zeros((bsz, beam_size))

    # mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet
-    incomplete_sentences = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long)
    incomplete_sentences = torch.ones(bsz * beam_size, dtype=torch.long)

    # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
    for step in range(model.config.max_seq_len):

        if step == 1:
            # duplicate and order encoder output so that each beam is treated as its own independent sequence
            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
            new_order = new_order.to(encoder_tokens.device).long()
            encoder_output = encoder_output.index_select(0, new_order)
            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)

        # causal mask and padding mask for decoder sequence
        tgt_len = decoder_tokens.shape[1]
        decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
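        # An illustrative check (not part of the tutorial) of what this causal mask
        # looks like for tgt_len = 3; True marks positions a token may NOT attend to:
        # torch.triu(torch.ones((3, 3), dtype=torch.float64), diagonal=1).bool()
        # -> tensor([[False,  True,  True],
        #            [False, False,  True],
        #            [False, False, False]])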
@@ -203,39 +268,39 @@ def greedy_generator(
        decoder_output = decoder_output * (model.config.embedding_dim ** -0.5)
        decoder_output = model.lm_head(decoder_output)

-        # greedy search for next token to add to sequence
-        probs = F.log_softmax(decoder_output[:, -1], dim=-1)
-        _, next_token = torch.topk(probs, 1)

-        # ignore next tokens for sentences that are already complete
-        next_token *= incomplete_sentences
        decoder_tokens, scores, incomplete_sentences = beam_search(
            beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences
        )
        # ignore newest tokens for sentences that are already complete
        decoder_tokens[:, -1] *= incomplete_sentences

        # update incomplete_sentences to remove those that were just ended
-        incomplete_sentences = incomplete_sentences - (next_token == eos_idx).long()

-        # update decoder sequences to include new tokens
-        decoder_tokens = torch.cat((decoder_tokens, next_token), 1)
        incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long()

        # early stop if all sentences have been ended
        if (incomplete_sentences == 0).all():
            break

    # take most likely sequence (beams come back in descending score order, so index 0 is the best)
    decoder_tokens = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :]
    return decoder_tokens


#######################################################################
# Generate Summaries
# ------------------
#
-# Finally we put all of the components together to generate summaries on the first batch of articles in the CNNDM test set.
# Finally we put all of the components together to generate summaries on the first batch of articles in the CNNDM test set
# using a beam size of 3.
#

batch = next(iter(test_dataloader))
input_text = batch["article"]
model_input = transform(input_text)
target = batch["abstract"]
beam_size = 3

-model_output = greedy_generator(model=model, encoder_tokens=model_input, eos_idx=eos_idx)
model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(batch_size):
@@ -253,10 +318,10 @@ def greedy_generator(
#
# Example 1:
#
-# prediction: the Palestinians officially become the 123rd member of the international
-# criminal court . the move gives the court jurisdiction over alleged crimes committed
-# in the occupied Palestinian territory . the ICC opened a preliminary examination into
-# the situation in the occupied territories .
# prediction: the Palestinians become the 123rd member of the international criminal
# court . the accession was marked by a ceremony at the Hague, where the court is based .
# the ICC opened a preliminary examination into the situation in the occupied
# Palestinian territory .
#
# target: Membership gives the ICC jurisdiction over alleged crimes committed in
# Palestinian territories since last June . Israel and the United States opposed the
@@ -265,10 +330,10 @@ def greedy_generator(
#
# Example 2:
#
-# prediction: a stray pooch in Washington state has used up at least three of her own
-# after being hit by a car . the dog staggers to a nearby farm, dirt-covered and
-# emaciated, where she is found . she suffered a dislocated jaw, leg injuries and a
-# caved-in sinus cavity .
# prediction: a stray pooch has used up at least three of her own after being hit by a
# car and buried in a field . the dog managed to stagger to a nearby farm, dirt-covered
# and emaciated, where she was found . she suffered a dislocated jaw, leg injuries and a
# caved-in sinus cavity -- and still requires surgery to help her breathe .
#
# target: Theia, a bully breed mix, was apparently hit by a car, whacked with a hammer
# and buried in a field . "She's a true miracle dog and she deserves a good life," says
@@ -277,9 +342,9 @@ def greedy_generator(
#
# Example 3:
#
-# prediction: mohammad Javad Zarif is the foreign minister of the country . he has been
-# a key figure in securing a breakthrough in nuclear talks . he has been a hero in the
-# international community .
# prediction: mohammad Javad Zarif arrived in Iran on a sunny friday morning . he has gone
# a long way to bring Iran in from the cold and allow it to rejoin the international
# community . but there are some facts about him that are less well-known .
#
# target: Mohammad Javad Zarif has spent more time with John Kerry than any other
# foreign minister . He once participated in a takeover of the Iranian Consulate in San
@@ -288,9 +353,9 @@ def greedy_generator(
#
# Example 4:
#
-# prediction: five americans were monitored for three weeks after being exposed to
-# Ebola . one of the five had a heart-related issue on Saturday and has been discharged .
-# none of the patients developed the deadly virus .
# prediction: five americans were monitored for three weeks after being exposed to Ebola in
# west africa . one of the five had a heart-related issue and has been discharged but hasn't
# left the area . they are clinicians for Partners in Health, a Boston-based aid group .
#
# target: 17 Americans were exposed to the Ebola virus while in Sierra Leone in March .
# Another person was diagnosed with the disease and taken to hospital in Maryland .
@@ -302,7 +367,8 @@ def greedy_generator(
#
# prediction: the student was identified during an investigation by campus police and
# the office of student affairs . he admitted to placing the noose on the tree early
-# Wednesday morning .
# Wednesday morning . the incident is one of several recent racist events to affect
# college students .
#
# target: Student is no longer on Duke University campus and will face disciplinary
# review . School officials identified student during investigation and the person
Expand Down