@@ -1,22 +1,21 @@
+import warnings
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 from torchtext.prototype.models import (
+    T5_11B_GENERATION,
+    T5_3B_GENERATION,
     T5_BASE_GENERATION,
-    T5_SMALL_GENERATION,
     T5_LARGE_GENERATION,
-    T5_3B_GENERATION,
-    T5_11B_GENERATION,
+    T5_SMALL_GENERATION,
+    T5Bundle,
     T5Conf,
     T5Transform,
-    T5Bundle,
 )
 
-import warnings
-
 
 BUNDLERS = {
     "base": T5_BASE_GENERATION,
@@ -139,7 +138,6 @@ def beam_search(
         return new_decoder_tokens, new_scores, new_incomplete_sentences
 
     def generate(self, encoder_tokens: Tensor, beam_size: int, eos_idx: int = 1, max_seq_len: int = 512) -> Tensor:
-
         # pass tokens through encoder
         bsz = encoder_tokens.size(0)
         encoder = self.model.get_encoder()
@@ -155,7 +153,6 @@ def generate(self, encoder_tokens: Tensor, beam_size: int, eos_idx: int = 1, max
 
         # iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
         for step in range(max_seq_len):
-
             if step == 1:
                 # duplicate and order encoder output so that each beam is treated as its own independent sequence
                 encoder_output = encoder_outputs.get("encoder_output")
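The duplication at `step == 1` is the standard beam-search bookkeeping: after the first step there are `bsz * beam_size` active hypotheses, so each input's encoder output must be tiled `beam_size` times for the decoder to treat every beam as an independent sequence. A self-contained toy illustration of that tiling (not the diff's own code):

```python
import torch

bsz, src_len, hidden, beam_size = 2, 4, 8, 3
encoder_output = torch.randn(bsz, src_len, hidden)

# Row i * beam_size + j of the result holds beam j of input sequence i,
# which is the ordering beam search expects.
tiled = encoder_output.repeat_interleave(beam_size, dim=0)
print(tiled.shape)  # torch.Size([6, 4, 8])
```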
@@ -189,7 +186,6 @@ def generate(self, encoder_tokens: Tensor, beam_size: int, eos_idx: int = 1, max
         return decoder_tokens
 
     def forward(self, input_text: List[str], beam_size: int, max_seq_len: int) -> Union[List[str], str]:
-
         model_input = self.transform(input_text)
         model_output_tensor = self.generate(encoder_tokens=model_input, beam_size=beam_size, max_seq_len=max_seq_len)
         model_output_list = torch.jit.annotate(List[List[int]], model_output_tensor.tolist())
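The `torch.jit.annotate` call here is for TorchScript: `Tensor.tolist()` has no statically inferable element type, so the annotation pins the result to `List[List[int]]` (a 2-D integer tensor). A minimal standalone sketch of the same pattern:

```python
from typing import List

import torch

@torch.jit.script
def ids_to_list(t: torch.Tensor) -> List[List[int]]:
    # Without the annotation, TorchScript cannot type the result of tolist().
    return torch.jit.annotate(List[List[int]], t.tolist())

print(ids_to_list(torch.tensor([[1, 2], [3, 4]])))  # [[1, 2], [3, 4]]
```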