|
3 | 3 | ========================================================================== |
4 | 4 |
|
5 | 5 | **Author**: `Pendo Abbo <[email protected]>`__ |
| 6 | +**Author**: `Joe Cummings <[email protected]>`__ |
6 | 7 |
|
7 | 8 | """ |
8 | 9 |
|
|
24 | 25 | # Common imports |
25 | 26 | # -------------- |
26 | 27 | import torch |
27 | | -import torch.nn.functional as F |
28 | 28 |
|
29 | 29 | DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") |
30 | 30 |
|
|
47 | 47 | # the T5 model expects the input to be batched. |
48 | 48 | # |
49 | 49 |
|
50 | | -from torchtext.prototype.models import T5Transform |
| 50 | +from torchtext.models import T5Transform |
51 | 51 |
|
52 | 52 | padding_idx = 0 |
53 | 53 | eos_idx = 1 |
|
66 | 66 | # |
67 | 67 | # :: |
68 | 68 | # |
69 | | -# from torchtext.prototype.models import T5_BASE_GENERATION |
| 69 | +# from torchtext.models import T5_BASE_GENERATION |
70 | 70 | # transform = T5_BASE_GENERATION.transform() |
71 | 71 | # |
72 | 72 |
|
|
81 | 81 | # https://pytorch.org/text/main/models.html |
82 | 82 | # |
83 | 83 | # |
84 | | -from torchtext.prototype.models import T5_BASE_GENERATION |
| 84 | +from torchtext.models import T5_BASE_GENERATION |
85 | 85 |
|
86 | 86 |
|
87 | 87 | t5_base = T5_BASE_GENERATION |
|
92 | 92 |
|
93 | 93 |
|
94 | 94 | ####################################################################### |
95 | | -# Sequence Generator |
| 95 | +# GenerationUtils |
96 | 96 | # ------------------ |
97 | 97 | # |
98 | | -# We can define a sequence generator to produce an output sequence based on the input sequence provided. This calls on the |
| 98 | +# We can use torchtext's `GenerationUtils` to produce an output sequence based on the input sequence provided. This calls on the |
99 | 99 | # model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated |
100 | | -# for all sequences in the batch. The `generate` method shown below uses a beam search to generate the sequences. Larger |
101 | | -# beam sizes can result in better generation at the cost of computational complexity, and a beam size of 1 is equivalent to |
102 | | -# a greedy decoder. |
103 | | -# |
104 | | - |
105 | | -from torch import Tensor |
106 | | -from torchtext.prototype.models import T5Model |
107 | | - |
108 | | - |
109 | | -def beam_search( |
110 | | - beam_size: int, |
111 | | - step: int, |
112 | | - bsz: int, |
113 | | - decoder_output: Tensor, |
114 | | - decoder_tokens: Tensor, |
115 | | - scores: Tensor, |
116 | | - incomplete_sentences: Tensor, |
117 | | -): |
118 | | - probs = F.log_softmax(decoder_output[:, -1], dim=-1) |
119 | | - top = torch.topk(probs, beam_size) |
120 | | - |
121 | | - # N is number of sequences in decoder_tokens, L is length of sequences, B is beam_size |
122 | | - # decoder_tokens has shape (N,L) -> (N,B,L) |
123 | | - # top.indices has shape (N,B) - > (N,B,1) |
124 | | - # x has shape (N,B,L+1) |
125 | | - # note that when step == 1, N = batch_size, and when step > 1, N = batch_size * beam_size |
126 | | - x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top.indices.unsqueeze(-1)], dim=-1) |
127 | | - |
128 | | - # beams are first created for a given sequence |
129 | | - if step == 1: |
130 | | - # x has shape (batch_size, B, L+1) -> (batch_size * B, L+1) |
131 | | - # new_scores has shape (batch_size,B) |
132 | | - # incomplete_sentences has shape (batch_size * B) = (N) |
133 | | - new_decoder_tokens = x.view(-1, step + 1) |
134 | | - new_scores = top.values |
135 | | - new_incomplete_sentences = incomplete_sentences |
136 | | - |
137 | | - # beams already exist, want to expand each beam into possible new tokens to add |
138 | | - # and for all expanded beams beloning to the same sequences, choose the top k |
139 | | - else: |
140 | | - # scores has shape (batch_size,B) -> (N,1) -> (N,B) |
141 | | - # top.values has shape (N,B) |
142 | | - # new_scores has shape (N,B) -> (batch_size, B^2) |
143 | | - new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(bsz, -1) |
144 | | - |
145 | | - # v, i have shapes (batch_size, B) |
146 | | - v, i = torch.topk(new_scores, beam_size) |
147 | | - |
148 | | - # x has shape (N,B,L+1) -> (batch_size, B, L+1) |
149 | | - # i has shape (batch_size, B) -> (batch_size, B, L+1) |
150 | | - # new_decoder_tokens has shape (batch_size, B, L+1) -> (N, L) |
151 | | - x = x.view(bsz, -1, step + 1) |
152 | | - new_decoder_tokens = x.gather(index=i.unsqueeze(-1).repeat(1, 1, step + 1), dim=1).view(-1, step + 1) |
153 | | - |
154 | | - # need to update incomplete sentences in case one of the beams was kicked out |
155 | | - # y has shape (N) -> (N, 1) -> (N, B) -> (batch_size, B^2) |
156 | | - y = incomplete_sentences.unsqueeze(-1).repeat(1, beam_size).view(bsz, -1) |
157 | | - |
158 | | - # now can use i to extract those beams that were selected |
159 | | - # new_incomplete_sentences has shape (batch_size, B^2) -> (batch_size, B) -> (N, 1) -> N |
160 | | - new_incomplete_sentences = y.gather(index=i, dim=1).view(bsz * beam_size, 1).squeeze(-1) |
161 | | - |
162 | | - # new_scores has shape (batch_size, B) |
163 | | - new_scores = v |
164 | | - |
165 | | - return new_decoder_tokens, new_scores, new_incomplete_sentences |
166 | | - |
167 | | - |
168 | | -def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor: |
169 | | - |
170 | | - # pass tokens through encoder |
171 | | - bsz = encoder_tokens.size(0) |
172 | | - encoder_padding_mask = encoder_tokens.eq(model.padding_idx) |
173 | | - encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens)) |
174 | | - encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0] |
175 | | - |
176 | | - encoder_output = model.norm1(encoder_output) |
177 | | - encoder_output = model.dropout2(encoder_output) |
178 | | - |
179 | | - # initialize decoder input sequence; T5 uses padding index as starter index to decoder sequence |
180 | | - decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * model.padding_idx |
181 | | - scores = torch.zeros((bsz, beam_size)) |
182 | | - |
183 | | - # mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet |
184 | | - incomplete_sentences = torch.ones(bsz * beam_size, dtype=torch.long) |
185 | | - |
186 | | - # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token |
187 | | - for step in range(model.config.max_seq_len): |
188 | | - |
189 | | - if step == 1: |
190 | | - # duplicate and order encoder output so that each beam is treated as its own independent sequence |
191 | | - new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) |
192 | | - new_order = new_order.to(encoder_tokens.device).long() |
193 | | - encoder_output = encoder_output.index_select(0, new_order) |
194 | | - encoder_padding_mask = encoder_padding_mask.index_select(0, new_order) |
195 | | - |
196 | | - # causal mask and padding mask for decoder sequence |
197 | | - tgt_len = decoder_tokens.shape[1] |
198 | | - decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool() |
199 | | - decoder_padding_mask = decoder_tokens.eq(model.padding_idx) |
200 | | - |
201 | | - # T5 implemention uses padding idx to start sequence. Want to ignore this when masking |
202 | | - decoder_padding_mask[:, 0] = False |
203 | | - |
204 | | - # pass decoder sequence through decoder |
205 | | - decoder_embeddings = model.dropout3(model.token_embeddings(decoder_tokens)) |
206 | | - decoder_output = model.decoder( |
207 | | - decoder_embeddings, |
208 | | - memory=encoder_output, |
209 | | - tgt_mask=decoder_mask, |
210 | | - tgt_key_padding_mask=decoder_padding_mask, |
211 | | - memory_key_padding_mask=encoder_padding_mask, |
212 | | - )[0] |
213 | | - |
214 | | - decoder_output = model.norm2(decoder_output) |
215 | | - decoder_output = model.dropout4(decoder_output) |
216 | | - decoder_output = decoder_output * (model.config.embedding_dim ** -0.5) |
217 | | - decoder_output = model.lm_head(decoder_output) |
218 | | - |
219 | | - decoder_tokens, scores, incomplete_sentences = beam_search( |
220 | | - beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences |
221 | | - ) |
222 | | - # ignore newest tokens for sentences that are already complete |
223 | | - decoder_tokens[:, -1] *= incomplete_sentences |
224 | | - |
225 | | - # update incomplete_sentences to remove those that were just ended |
226 | | - incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long() |
227 | | - |
228 | | - # early stop if all sentences have been ended |
229 | | - if (incomplete_sentences == 0).all(): |
230 | | - break |
231 | | - |
232 | | - # take most likely sequence |
233 | | - decoder_tokens = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :] |
234 | | - return decoder_tokens |
| 100 | +# for all sequences in the batch. The `generate` method shown below uses greedy search to generate the sequences. Beam search and |
| 101 | +# other decoding strategies are also supported. |
| 102 | +# |
| 103 | +# |
| 104 | +from torchtext.prototype.generate import GenerationUtils |
| 105 | + |
| 106 | +sequence_generator = GenerationUtils(model) |
235 | 107 |
|
236 | 108 |
|
237 | 109 | ####################################################################### |
@@ -343,16 +215,16 @@ def process_labels(labels, x): |
343 | 215 | # ------------------ |
344 | 216 | # |
345 | 217 | # We can put all of the components together to generate summaries on the first batch of articles in the CNNDM test set |
346 | | -# using a beam size of 3. |
| 218 | +# using a beam size of 1. |
347 | 219 | # |
348 | 220 |
|
349 | 221 | batch = next(iter(cnndm_dataloader)) |
350 | 222 | input_text = batch["article"] |
351 | 223 | target = batch["abstract"] |
352 | | -beam_size = 3 |
| 224 | +beam_size = 1 |
353 | 225 |
|
354 | 226 | model_input = transform(input_text) |
355 | | -model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size) |
| 227 | +model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size) |
356 | 228 | output_text = transform.decode(model_output.tolist()) |
357 | 229 |
|
358 | 230 | for i in range(cnndm_batch_size): |
@@ -442,7 +314,7 @@ def process_labels(labels, x): |
442 | 314 | beam_size = 1 |
443 | 315 |
|
444 | 316 | model_input = transform(input_text) |
445 | | -model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size) |
| 317 | +model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size) |
446 | 318 | output_text = transform.decode(model_output.tolist()) |
447 | 319 |
|
448 | 320 | for i in range(imdb_batch_size): |
@@ -536,7 +408,7 @@ def process_labels(labels, x): |
536 | 408 | beam_size = 4 |
537 | 409 |
|
538 | 410 | model_input = transform(input_text) |
539 | | -model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size) |
| 411 | +model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size) |
540 | 412 | output_text = transform.decode(model_output.tolist()) |
541 | 413 |
|
542 | 414 | for i in range(multi_batch_size): |
|
0 commit comments