@@ -43,7 +43,8 @@ def __init__(self, ntoken, ninp, dropout=0.5):
         self.norm = LayerNorm(ninp)
         self.dropout = Dropout(dropout)

-    def forward(self, src, token_type_input):
+    def forward(self, seq_inputs):
+        src, token_type_input = seq_inputs
         src = self.embed(src) + self.pos_embed(src) \
             + self.tok_type_embed(src, token_type_input)
         return self.dropout(self.norm(src))
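With this change the two inputs travel through the embedding as a single tuple. A minimal sketch of the new call convention, assuming the file in this diff is importable as model.py and using made-up sizes:

    import torch
    from model import BertEmbedding  # assumption: this diff's file is model.py

    ntoken, ninp, seq_len, batch = 1000, 64, 35, 8               # hypothetical sizes
    embedding = BertEmbedding(ntoken, ninp)
    src = torch.randint(0, ntoken, (seq_len, batch))             # (S, N) token ids
    token_types = torch.zeros(seq_len, batch, dtype=torch.long)  # (S, N) segment ids
    out = embedding((src, token_types))  # one positional argument: the tuple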
@@ -99,16 +100,16 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None):
 class BertModel(nn.Module):
     """Contain a transformer encoder."""

-    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
+    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5):
         super(BertModel, self).__init__()
         self.model_type = 'Transformer'
-        self.bert_embed = BertEmbedding(ntoken, ninp)
+        self.bert_embed = embed_layer
         encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
         self.ninp = ninp

-    def forward(self, src, token_type_input):
-        src = self.bert_embed(src, token_type_input)
+    def forward(self, seq_inputs):
+        src = self.bert_embed(seq_inputs)
         output = self.transformer_encoder(src)
         return output

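BertModel no longer constructs its own BertEmbedding; the embedding module is now injected by the caller. A sketch of that injection point with hypothetical hyperparameters; any module whose forward accepts the same (src, token_type_input) tuple could be swapped in:

    from model import BertEmbedding, BertModel  # assumption: model.py from this diff

    embed_layer = BertEmbedding(ntoken=1000, ninp=64)
    encoder = BertModel(ntoken=1000, ninp=64, nhead=8, nhid=256,
                        nlayers=2, embed_layer=embed_layer, dropout=0.2)
    # src, token_types as in the previous sketch
    output = encoder((src, token_types))  # same single-tuple convention, (S, N, E) out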
@@ -118,15 +119,16 @@ class MLMTask(nn.Module):

     def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
         super(MLMTask, self).__init__()
-        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, dropout=0.5)
+        embed_layer = BertEmbedding(ntoken, ninp)
+        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5)
         self.mlm_span = Linear(ninp, ninp)
         self.activation = F.gelu
         self.norm_layer = LayerNorm(ninp, eps=1e-12)
         self.mlm_head = Linear(ninp, ntoken)

     def forward(self, src, token_type_input=None):
         src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         output = self.mlm_span(output)
         output = self.activation(output)
         output = self.norm_layer(output)
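Note that MLMTask keeps its original two-argument forward, so external training code is untouched; only the internal call packs the tuple. A sketch, assuming (as the None default suggests) the token-type encoder tolerates a missing segment input:

    import torch
    from model import MLMTask  # assumption: model.py from this diff

    mlm = MLMTask(ntoken=1000, ninp=64, nhead=8, nhid=256, nlayers=2)
    src = torch.randint(0, 1000, (8, 35))  # (N, S); forward transposes to (S, N)
    logits = mlm(src)                      # token_type_input defaults to None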
@@ -147,7 +149,7 @@ def __init__(self, bert_model):

     def forward(self, src, token_type_input):
         src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # Send the first <'cls'> seq to a classifier
         output = self.activation(self.linear_layer(output[0]))
         output = self.ns_span(output)
@@ -164,7 +166,7 @@ def __init__(self, bert_model):
         self.qa_span = Linear(bert_model.ninp, 2)

     def forward(self, src, token_type_input):
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # transpose output (S, N, E) to (N, S, E)
         output = output.transpose(0, 1)
         output = self.activation(output)
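The diff itself does not state the motivation, but one plausible payoff of a single-argument forward is composability with containers that thread exactly one value between stages, such as nn.Sequential: a tuple rides along as that single value, which the old two-argument signature could not do. A self-contained sketch under that assumption:

    import torch
    import torch.nn as nn
    from model import BertEmbedding  # assumption: model.py from this diff

    embed_layer = BertEmbedding(1000, 64)
    encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(64, 8, 256), 2)
    stages = nn.Sequential(embed_layer, encoder)  # tuple in, tensor between stages

    src = torch.randint(0, 1000, (35, 8))
    token_types = torch.zeros(35, 8, dtype=torch.long)
    output = stages((src, token_types))           # (S, N, E)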