@@ -149,21 +149,78 @@ def apply_prefix(task, x):
 #
 # We can define a sequence generator to produce an output sequence based on the input sequence provided. This calls on the
 # model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
-# for all sequences in the batch. The `greedy_generator` method shown below uses a greedy search (i.e. expands the sequence
-# based on the most probable next word).
+# for all sequences in the batch. The `generate` method shown below uses a beam search to generate the sequences. Larger
+# beam sizes can result in better generation at the cost of computational complexity, and a beam size of 1 is equivalent to
+# a greedy decoder.
 #
 
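+# As a minimal sketch of that last point (using a made-up logits vector rather than
+# real model output), keeping the top-1 candidate of a beam expansion is exactly the
+# greedy argmax choice:
+
+import torch
+import torch.nn.functional as F
+
+toy_logits = torch.tensor([[1.0, 3.0, 2.0]])
+toy_probs = F.log_softmax(toy_logits, dim=-1)
+# a beam of size 1 keeps only the single most probable next token
+assert torch.topk(toy_probs, 1).indices.item() == toy_probs.argmax().item()
+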
 from torch import Tensor
 from torchtext.prototype.models import T5Model
 
 
-def greedy_generator(
-    encoder_tokens: Tensor,
-    eos_idx: int,
-    model: T5Model,
-) -> Tensor:
+def beam_search(
+    beam_size: int,
+    step: int,
+    bsz: int,
+    decoder_output: Tensor,
+    decoder_tokens: Tensor,
+    scores: Tensor,
+    incomplete_sentences: Tensor,
+):
+    probs = F.log_softmax(decoder_output[:, -1], dim=-1)
+    top = torch.topk(probs, beam_size)
+
+    # N is the number of sequences in decoder_tokens, L is the length of the sequences, B is beam_size
+    # decoder_tokens has shape (N, L) -> (N, B, L)
+    # top.indices has shape (N, B) -> (N, B, 1)
+    # x has shape (N, B, L+1)
+    # note that when step == 1, N = batch_size, and when step > 1, N = batch_size * beam_size
+    x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top.indices.unsqueeze(-1)], dim=-1)
+
+    # beams are first created for a given sequence
+    if step == 1:
+        # x has shape (batch_size, B, L+1) -> (batch_size * B, L+1)
+        # new_scores has shape (batch_size, B)
+        # incomplete_sentences has shape (batch_size * B) = (N)
+        new_decoder_tokens = x.view(-1, step + 1)
+        new_scores = top.values
+        new_incomplete_sentences = incomplete_sentences
+
+    # beams already exist; we want to expand each beam with the possible new tokens to add,
+    # and for all expanded beams belonging to the same sequence, choose the top k
+    else:
+        # scores has shape (batch_size, B) -> (N, 1) -> (N, B)
+        # top.values has shape (N, B)
+        # new_scores has shape (N, B) -> (batch_size, B^2)
+        new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(bsz, -1)
+
+        # v, i have shapes (batch_size, B)
+        v, i = torch.topk(new_scores, beam_size)
+
+        # x has shape (N, B, L+1) -> (batch_size, B^2, L+1)
+        # i has shape (batch_size, B) -> (batch_size, B, L+1)
+        # new_decoder_tokens has shape (batch_size, B, L+1) -> (N, L+1)
+        x = x.view(bsz, -1, step + 1)
+        new_decoder_tokens = x.gather(index=i.unsqueeze(-1).repeat(1, 1, step + 1), dim=1).view(-1, step + 1)
+
+        # need to update incomplete sentences in case one of the beams was kicked out
+        # y has shape (N) -> (N, 1) -> (N, B) -> (batch_size, B^2)
+        y = incomplete_sentences.unsqueeze(-1).repeat(1, beam_size).view(bsz, -1)
+
+        # now we can use i to extract the beams that were selected
+        # new_incomplete_sentences has shape (batch_size, B^2) -> (batch_size, B) -> (N, 1) -> (N)
+        new_incomplete_sentences = y.gather(index=i, dim=1).view(bsz * beam_size, 1).squeeze(-1)
+
+        # new_scores has shape (batch_size, B)
+        new_scores = v
+
+    return new_decoder_tokens, new_scores, new_incomplete_sentences
+
+
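+# Before wiring `beam_search` into the decoding loop, here is a quick shape sanity
+# check. This is only a sketch: the batch size, beam size, sequence length, and
+# vocabulary size below are arbitrary made-up values, not tied to the T5 config.
+
+sanity_bsz, sanity_beam, sanity_len, sanity_vocab = 2, 3, 4, 100
+# at step > 1, N = batch_size * beam_size sequences are being expanded
+sanity_output = torch.randn(sanity_bsz * sanity_beam, sanity_len, sanity_vocab)
+sanity_tokens = torch.zeros((sanity_bsz * sanity_beam, sanity_len), dtype=torch.long)
+sanity_scores = torch.zeros((sanity_bsz, sanity_beam))
+sanity_incomplete = torch.ones(sanity_bsz * sanity_beam, dtype=torch.long)
+sanity_tokens, sanity_scores, sanity_incomplete = beam_search(
+    sanity_beam, sanity_len, sanity_bsz, sanity_output, sanity_tokens, sanity_scores, sanity_incomplete
+)
+# each of the N = 6 surviving beams grew by one token: (N, L) -> (N, L+1)
+assert sanity_tokens.shape == (sanity_bsz * sanity_beam, sanity_len + 1)
+assert sanity_scores.shape == (sanity_bsz, sanity_beam)
+
+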
+def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor:
 
     # pass tokens through encoder
+    bsz = encoder_tokens.size(0)
     encoder_padding_mask = encoder_tokens.eq(model.padding_idx)
     encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens))
     encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0]
@@ -172,14 +229,22 @@ def greedy_generator(
     encoder_output = model.dropout2(encoder_output)
 
     # initialize decoder input sequence; T5 uses padding index as starter index to decoder sequence
-    decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * model.padding_idx
+    decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * model.padding_idx
+    scores = torch.zeros((bsz, beam_size))
 
     # mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet
-    incomplete_sentences = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long)
+    incomplete_sentences = torch.ones(bsz * beam_size, dtype=torch.long)
 
     # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
     for step in range(model.config.max_seq_len):
 
+        if step == 1:
+            # duplicate and order encoder output so that each beam is treated as its own independent sequence
+            new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
+            new_order = new_order.to(encoder_tokens.device).long()
+            encoder_output = encoder_output.index_select(0, new_order)
+            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
+
         # causal mask and padding mask for decoder sequence
         tgt_len = decoder_tokens.shape[1]
         decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
@@ -203,39 +268,39 @@ def greedy_generator(
         decoder_output = decoder_output * (model.config.embedding_dim ** -0.5)
         decoder_output = model.lm_head(decoder_output)
 
-        # greedy search for next token to add to sequence
-        probs = F.log_softmax(decoder_output[:, -1], dim=-1)
-        _, next_token = torch.topk(probs, 1)
-
-        # ignore next tokens for sentences that are already complete
-        next_token *= incomplete_sentences
+        decoder_tokens, scores, incomplete_sentences = beam_search(
+            beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences
+        )
+        # ignore newest tokens for sentences that are already complete
+        decoder_tokens[:, -1] *= incomplete_sentences
 
         # update incomplete_sentences to remove those that were just ended
-        incomplete_sentences = incomplete_sentences - (next_token == eos_idx).long()
-
-        # update decoder sequences to include new tokens
-        decoder_tokens = torch.cat((decoder_tokens, next_token), 1)
+        incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long()
 
         # early stop if all sentences have been ended
         if (incomplete_sentences == 0).all():
             break
 
+    # take the most likely sequence
+    decoder_tokens = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :]
     return decoder_tokens
 
 
 #######################################################################
 # Generate Summaries
 # ------------------
 #
-# Finally we put all of the components together to generate summaries on the first batch of articles in the CNNDM test set.
+# Finally we put all of the components together to generate summaries on the first batch of articles in the CNNDM test set
+# using a beam size of 3.
 #
 
 batch = next(iter(test_dataloader))
 input_text = batch["article"]
 model_input = transform(input_text)
 target = batch["abstract"]
+beam_size = 3
 
-model_output = greedy_generator(model=model, encoder_tokens=model_input, eos_idx=eos_idx)
+model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
 output_text = transform.decode(model_output.tolist())
 
 for i in range(batch_size):
@@ -253,10 +318,10 @@ def greedy_generator(
 #
 # Example 1:
 #
-# prediction: the Palestinians officially become the 123rd member of the international
-# criminal court . the move gives the court jurisdiction over alleged crimes committed
-# in the occupied Palestinian territory . the ICC opened a preliminary examination into
-# the situation in the occupied territories .
+# prediction: the Palestinians become the 123rd member of the international criminal
+# court . the accession was marked by a ceremony at the Hague, where the court is based .
+# the ICC opened a preliminary examination into the situation in the occupied
+# Palestinian territory .
 #
 # target: Membership gives the ICC jurisdiction over alleged crimes committed in
 # Palestinian territories since last June . Israel and the United States opposed the
@@ -265,10 +330,10 @@ def greedy_generator(
 #
 # Example 2:
 #
-# prediction: a stray pooch in Washington state has used up at least three of her own
-# after being hit by a car . the dog staggers to a nearby farm, dirt-covered and
-# emaciated, where she is found . she suffered a dislocated jaw, leg injuries and a
-# caved-in sinus cavity .
+# prediction: a stray pooch has used up at least three of her own after being hit by a
+# car and buried in a field . the dog managed to stagger to a nearby farm, dirt-covered
+# and emaciated, where she was found . she suffered a dislocated jaw, leg injuries and a
+# caved-in sinus cavity -- and still requires surgery to help her breathe .
 #
 # target: Theia, a bully breed mix, was apparently hit by a car, whacked with a hammer
 # and buried in a field . "She's a true miracle dog and she deserves a good life," says
@@ -277,9 +342,9 @@ def greedy_generator(
 #
 # Example 3:
 #
-# prediction: mohammad Javad Zarif is the foreign minister of the country . he has been
-# a key figure in securing a breakthrough in nuclear talks . he has been a hero in the
-# international community .
+# prediction: mohammad Javad Zarif arrived in Iran on a sunny friday morning . he has gone
+# a long way to bring Iran in from the cold and allow it to rejoin the international
+# community . but there are some facts about him that are less well-known .
 #
 # target: Mohammad Javad Zarif has spent more time with John Kerry than any other
 # foreign minister . He once participated in a takeover of the Iranian Consulate in San
@@ -288,9 +353,9 @@ def greedy_generator(
 #
 # Example 4:
 #
-# prediction: five americans were monitored for three weeks after being exposed to
-# Ebola . one of the five had a heart-related issue on Saturday and has been discharged .
-# none of the patients developed the deadly virus .
+# prediction: five americans were monitored for three weeks after being exposed to Ebola in
+# west africa . one of the five had a heart-related issue and has been discharged but hasn't
+# left the area . they are clinicians for Partners in Health, a Boston-based aid group .
 #
 # target: 17 Americans were exposed to the Ebola virus while in Sierra Leone in March .
 # Another person was diagnosed with the disease and taken to hospital in Maryland .
@@ -302,7 +367,8 @@ def greedy_generator(
 #
 # prediction: the student was identified during an investigation by campus police and
 # the office of student affairs . he admitted to placing the noose on the tree early
-# Wednesday morning .
+# Wednesday morning . the incident is one of several recent racist events to affect
+# college students .
 #
 # target: Student is no longer on Duke University campus and will face disciplinary
 # review . School officials identified student during investigation and the person