
Commit 2bb2562

Updating T5 demo to use beam search for generator (#1869)
* updating demo to use beam search for generator
* details on beam size
1 parent 8eb0561 commit 2bb2562

File tree

1 file changed: +102 -36 lines


examples/tutorials/cnndm_summarization.py

Lines changed: 102 additions & 36 deletions
@@ -149,21 +149,78 @@ def apply_prefix(task, x):
 #
 # We can define a sequence generator to produce an output sequence based on the input sequence provided. This calls on the
 # model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
-# for all sequences in the batch. The `greedy_generator` method shown below uses a greedy search (i.e. expands the sequence
-# based on the most probable next word).
+# for all sequences in the batch. The `generate` method shown below uses a beam search to generate the sequences. Larger
+# beam sizes can result in better generation at the cost of computational complexity, and a beam size of 1 is equivalent to
+# a greedy decoder.
 #

 from torch import Tensor
 from torchtext.prototype.models import T5Model


-def greedy_generator(
-    encoder_tokens: Tensor,
-    eos_idx: int,
-    model: T5Model,
-) -> Tensor:
+def beam_search(
+    beam_size: int,
+    step: int,
+    bsz: int,
+    decoder_output: Tensor,
+    decoder_tokens: Tensor,
+    scores: Tensor,
+    incomplete_sentences: Tensor,
+):
+    probs = F.log_softmax(decoder_output[:, -1], dim=-1)
+    top = torch.topk(probs, beam_size)
+
+    # N is number of sequences in decoder_tokens, L is length of sequences, B is beam_size
+    # decoder_tokens has shape (N,L) -> (N,B,L)
+    # top.indices has shape (N,B) -> (N,B,1)
+    # x has shape (N,B,L+1)
+    # note that when step == 1, N = batch_size, and when step > 1, N = batch_size * beam_size
+    x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top.indices.unsqueeze(-1)], dim=-1)
+
+    # beams are first created for a given sequence
+    if step == 1:
+        # x has shape (batch_size, B, L+1) -> (batch_size * B, L+1)
+        # new_scores has shape (batch_size,B)
+        # incomplete_sentences has shape (batch_size * B) = (N)
+        new_decoder_tokens = x.view(-1, step + 1)
+        new_scores = top.values
+        new_incomplete_sentences = incomplete_sentences
+
+    # beams already exist; want to expand each beam into possible new tokens to add,
+    # and for all expanded beams belonging to the same sequence, choose the top k
+    else:
+        # scores has shape (batch_size,B) -> (N,1) -> (N,B)
+        # top.values has shape (N,B)
+        # new_scores has shape (N,B) -> (batch_size, B^2)
+        new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(bsz, -1)
+
+        # v, i have shapes (batch_size, B)
+        v, i = torch.topk(new_scores, beam_size)
+
+        # x has shape (N,B,L+1) -> (batch_size, B, L+1)
+        # i has shape (batch_size, B) -> (batch_size, B, L+1)
+        # new_decoder_tokens has shape (batch_size, B, L+1) -> (N, L)
+        x = x.view(bsz, -1, step + 1)
+        new_decoder_tokens = x.gather(index=i.unsqueeze(-1).repeat(1, 1, step + 1), dim=1).view(-1, step + 1)
+
+        # need to update incomplete sentences in case one of the beams was kicked out
+        # y has shape (N) -> (N, 1) -> (N, B) -> (batch_size, B^2)
+        y = incomplete_sentences.unsqueeze(-1).repeat(1, beam_size).view(bsz, -1)
+
+        # now can use i to extract those beams that were selected
+        # new_incomplete_sentences has shape (batch_size, B^2) -> (batch_size, B) -> (N, 1) -> (N)
+        new_incomplete_sentences = y.gather(index=i, dim=1).view(bsz * beam_size, 1).squeeze(-1)
+
+        # new_scores has shape (batch_size, B)
+        new_scores = v
+
+    return new_decoder_tokens, new_scores, new_incomplete_sentences
+
+
+def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor:

     # pass tokens through encoder
+    bsz = encoder_tokens.size(0)
     encoder_padding_mask = encoder_tokens.eq(model.padding_idx)
     encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens))
     encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0]
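(Aside, not part of the commit: a minimal standalone sketch of the expansion step that beam_search performs, using made-up logits and beam scores. Each existing beam is extended with its top-B next tokens, the B^2 candidates per source sequence are flattened, and the best B survive; with beam_size = 1 the same topk calls reduce to an argmax, i.e. greedy decoding.)

import torch
import torch.nn.functional as F

beam_size = 2
# fake decoder output for N = batch_size * beam_size = 2 beams over a 5-token vocabulary
decoder_output = torch.tensor([[[0.1, 2.0, 0.3, 0.0, 1.5]],
                               [[1.0, 0.2, 2.5, 0.1, 0.0]]])
# running log-prob score of each existing beam (batch_size = 1, B = 2)
scores = torch.tensor([[-0.5, -1.2]])

probs = F.log_softmax(decoder_output[:, -1], dim=-1)  # (N, vocab)
top = torch.topk(probs, beam_size)                    # top-B continuations per beam
# add each beam's running score to its candidate continuations, flatten to (batch_size, B^2)
new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(1, -1)
v, i = torch.topk(new_scores, beam_size)              # keep the best B of the B^2 candidates
print(v, i)  # surviving beam scores and their flattened candidate indices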
@@ -172,14 +229,22 @@ def greedy_generator(
     encoder_output = model.dropout2(encoder_output)

     # initialize decoder input sequence; T5 uses padding index as starter index to decoder sequence
-    decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * model.padding_idx
+    decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * model.padding_idx
+    scores = torch.zeros((bsz, beam_size))

     # mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet
-    incomplete_sentences = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long)
+    incomplete_sentences = torch.ones(bsz * beam_size, dtype=torch.long)

     # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
     for step in range(model.config.max_seq_len):

+        if step == 1:
+            # duplicate and order encoder output so that each beam is treated as its own independent sequence
+            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+            new_order = new_order.to(encoder_tokens.device).long()
+            encoder_output = encoder_output.index_select(0, new_order)
+            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
+
         # causal mask and padding mask for decoder sequence
         tgt_len = decoder_tokens.shape[1]
         decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
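(Aside, not part of the commit: a quick standalone sketch of the new_order construction above, with an assumed batch size of 2 and beam size of 3, showing how index_select repeats each sequence's encoder output beam_size times back to back so the rows line up with the (batch_size * beam_size) beam layout.)

import torch

bsz, beam_size = 2, 3
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
print(new_order)  # tensor([0, 0, 0, 1, 1, 1])

# index_select with this ordering turns a (bsz, ...) encoder output into a
# (bsz * beam_size, ...) tensor where rows 0..B-1 all come from sequence 0, etc.
encoder_output = torch.randn(bsz, 4, 8)  # stand-in for (bsz, src_len, embed_dim)
expanded = encoder_output.index_select(0, new_order)
print(expanded.shape)  # torch.Size([6, 4, 8])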
@@ -203,39 +268,39 @@ def greedy_generator(
         decoder_output = decoder_output * (model.config.embedding_dim ** -0.5)
         decoder_output = model.lm_head(decoder_output)

-        # greedy search for next token to add to sequence
-        probs = F.log_softmax(decoder_output[:, -1], dim=-1)
-        _, next_token = torch.topk(probs, 1)
-
-        # ignore next tokens for sentences that are already complete
-        next_token *= incomplete_sentences
+        decoder_tokens, scores, incomplete_sentences = beam_search(
+            beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences
+        )
+        # ignore newest tokens for sentences that are already complete
+        decoder_tokens[:, -1] *= incomplete_sentences

         # update incomplete_sentences to remove those that were just ended
-        incomplete_sentences = incomplete_sentences - (next_token == eos_idx).long()
-
-        # update decoder sequences to include new tokens
-        decoder_tokens = torch.cat((decoder_tokens, next_token), 1)
+        incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long()

         # early stop if all sentences have been ended
         if (incomplete_sentences == 0).all():
             break

+    # take most likely sequence
+    decoder_tokens = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :]
     return decoder_tokens


 #######################################################################
 # Generate Summaries
 # ------------------
 #
-# Finally we put all of the components together to generate summaries on the first batch of articles in the CNNDM test set.
+# Finally we put all of the components together to generate summaries on the first batch of articles in the CNNDM test set
+# using a beam size of 3.
 #

 batch = next(iter(test_dataloader))
 input_text = batch["article"]
 model_input = transform(input_text)
 target = batch["abstract"]
+beam_size = 3

-model_output = greedy_generator(model=model, encoder_tokens=model_input, eos_idx=eos_idx)
+model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
 output_text = transform.decode(model_output.tolist())

 for i in range(batch_size):
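(Aside, not part of the commit: a toy sketch of the completion masking inside the generation loop, assuming eos_idx = 1 and T5's padding index 0 for illustration. Multiplying the newest token column by the 0/1 incomplete_sentences mask pads out beams that already finished, and comparing that column to eos_idx retires beams that just emitted end-of-sequence.)

import torch

eos_idx = 1
incomplete_sentences = torch.tensor([1, 1, 0])   # third beam already finished earlier
decoder_tokens = torch.tensor([[5, 7], [5, 1], [5, 9]])

decoder_tokens[:, -1] *= incomplete_sentences    # finished beam's new token -> 0 (pad)
incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long()
print(decoder_tokens)        # tensor([[5, 7], [5, 1], [5, 0]])
print(incomplete_sentences)  # tensor([1, 0, 0]): the beam that emitted eos is now done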
@@ -253,10 +318,10 @@ def greedy_generator(
 #
 # Example 1:
 #
-# prediction: the Palestinians officially become the 123rd member of the international
-# criminal court . the move gives the court jurisdiction over alleged crimes committed
-# in the occupied Palestinian territory . the ICC opened a preliminary examination into
-# the situation in the occupied territories .
+# prediction: the Palestinians become the 123rd member of the international criminal
+# court . the accession was marked by a ceremony at the Hague, where the court is based .
+# the ICC opened a preliminary examination into the situation in the occupied
+# Palestinian territory .
 #
 # target: Membership gives the ICC jurisdiction over alleged crimes committed in
 # Palestinian territories since last June . Israel and the United States opposed the
@@ -265,10 +330,10 @@ def greedy_generator(
 #
 # Example 2:
 #
-# prediction: a stray pooch in Washington state has used up at least three of her own
-# after being hit by a car . the dog staggers to a nearby farm, dirt-covered and
-# emaciated, where she is found . she suffered a dislocated jaw, leg injuries and a
-# caved-in sinus cavity .
+# prediction: a stray pooch has used up at least three of her own after being hit by a
+# car and buried in a field . the dog managed to stagger to a nearby farm, dirt-covered
+# and emaciated, where she was found . she suffered a dislocated jaw, leg injuries and a
+# caved-in sinus cavity -- and still requires surgery to help her breathe .
 #
 # target: Theia, a bully breed mix, was apparently hit by a car, whacked with a hammer
 # and buried in a field . "She's a true miracle dog and she deserves a good life," says
@@ -277,9 +342,9 @@ def greedy_generator(
 #
 # Example 3:
 #
-# prediction: mohammad Javad Zarif is the foreign minister of the country . he has been
-# a key figure in securing a breakthrough in nuclear talks . he has been a hero in the
-# international community .
+# prediction: mohammad Javad Zarif arrived in Iran on a sunny friday morning . he has gone
+# a long way to bring Iran in from the cold and allow it to rejoin the international
+# community . but there are some facts about him that are less well-known .
 #
 # target: Mohammad Javad Zarif has spent more time with John Kerry than any other
 # foreign minister . He once participated in a takeover of the Iranian Consulate in San
@@ -288,9 +353,9 @@ def greedy_generator(
 #
 # Example 4:
 #
-# prediction: five americans were monitored for three weeks after being exposed to
-# Ebola . one of the five had a heart-related issue on Saturday and has been discharged .
-# none of the patients developed the deadly virus .
+# prediction: five americans were monitored for three weeks after being exposed to Ebola in
+# west africa . one of the five had a heart-related issue and has been discharged but hasn't
+# left the area . they are clinicians for Partners in Health, a Boston-based aid group .
 #
 # target: 17 Americans were exposed to the Ebola virus while in Sierra Leone in March .
 # Another person was diagnosed with the disease and taken to hospital in Maryland .
@@ -302,7 +367,8 @@ def greedy_generator(
 #
 # prediction: the student was identified during an investigation by campus police and
 # the office of student affairs . he admitted to placing the noose on the tree early
-# Wednesday morning .
+# Wednesday morning . the incident is one of several recent racist events to affect
+# college students .
 #
 # target: Student is no longer on Duke University campus and will face disciplinary
 # review . School officials identified student during investigation and the person
