@@ -43,7 +43,8 @@ def __init__(self, ntoken, ninp, dropout=0.5):
         self.norm = LayerNorm(ninp)
         self.dropout = Dropout(dropout)

-    def forward(self, src, token_type_input):
+    def forward(self, seq_inputs):
+        src, token_type_input = seq_inputs
         src = self.embed(src) + self.pos_embed(src) \
             + self.tok_type_embed(src, token_type_input)
         return self.dropout(self.norm(src))
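The embedding's forward now takes a single packed tuple instead of two positional arguments. A likely motivation (inferred here, not stated in the diff) is that wrappers such as nn.Sequential forward exactly one positional value from stage to stage, so a packed input lets the embedding be chained directly. Below is a minimal, self-contained sketch of that pattern; ToyEmbedding and the sizes are made up for illustration and are not code from this repository.

# Illustrative only: a toy module with the packed-tuple forward, so it can sit
# inside nn.Sequential, which passes a single positional input to each stage.
import torch
from torch.nn import Embedding, Linear, Module, Sequential

class ToyEmbedding(Module):                      # hypothetical stand-in
    def __init__(self, ntoken, ninp):
        super().__init__()
        self.embed = Embedding(ntoken, ninp)
        self.tok_type_embed = Embedding(2, ninp)

    def forward(self, seq_inputs):
        src, token_type_input = seq_inputs       # unpack the single tuple argument
        return self.embed(src) + self.tok_type_embed(token_type_input)

ntoken, ninp, seq_len, bsz = 100, 8, 5, 3        # made-up sizes
src = torch.randint(0, ntoken, (seq_len, bsz))
token_type_input = torch.zeros(seq_len, bsz, dtype=torch.long)

emb = ToyEmbedding(ntoken, ninp)
out = emb((src, token_type_input))               # one packed positional argument
print(out.shape)                                 # torch.Size([5, 3, 8])

chained = Sequential(emb, Linear(ninp, ninp))    # each stage receives one value
print(chained((src, token_type_input)).shape)    # torch.Size([5, 3, 8])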
@@ -107,8 +108,8 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
         self.ninp = ninp

-    def forward(self, src, token_type_input):
-        src = self.bert_embed(src, token_type_input)
+    def forward(self, seq_inputs):
+        src = self.bert_embed(seq_inputs)
         output = self.transformer_encoder(src)
         return output

@@ -127,7 +128,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):

     def forward(self, src, token_type_input=None):
         src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         output = self.mlm_span(output)
         output = self.activation(output)
         output = self.norm_layer(output)
@@ -148,7 +149,7 @@ def __init__(self, bert_model):

     def forward(self, src, token_type_input):
         src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # Send the first <'cls'> seq to a classifier
         output = self.activation(self.linear_layer(output[0]))
         output = self.ns_span(output)
@@ -165,7 +166,7 @@ def __init__(self, bert_model):
         self.qa_span = Linear(bert_model.ninp, 2)

     def forward(self, src, token_type_input):
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # transpose output (S, N, E) to (N, S, E)
         output = output.transpose(0, 1)
         output = self.activation(output)
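The task heads keep their external (src, token_type_input) signature and only pack the pair into a tuple at the point where they call the wrapped model. A rough, self-contained sketch of that caller-side pattern follows; ToyPackedModel and ToyTaskHead are hypothetical stand-ins, not classes from this diff.

# Illustrative only: a toy head mirroring the hunks above -- old two-argument
# signature on the outside, a packed tuple passed to the wrapped model inside.
import torch
from torch.nn import Linear, Module

class ToyPackedModel(Module):                    # stands in for the packed-tuple model
    def __init__(self, ninp):
        super().__init__()
        self.proj = Linear(ninp, ninp)

    def forward(self, seq_inputs):
        src, token_type_input = seq_inputs       # token types ignored in this toy
        return self.proj(src)

class ToyTaskHead(Module):
    def __init__(self, model, ninp):
        super().__init__()
        self.model = model
        self.span = Linear(ninp, 2)

    def forward(self, src, token_type_input):
        output = self.model((src, token_type_input))   # pack at the call site, as in the diff
        return self.span(output)

ninp, seq_len, bsz = 8, 5, 3                     # made-up sizes
src = torch.randn(seq_len, bsz, ninp)
token_type_input = torch.zeros(seq_len, bsz, dtype=torch.long)
head = ToyTaskHead(ToyPackedModel(ninp), ninp)
print(head(src, token_type_input).shape)         # torch.Size([5, 3, 2])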