Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit beeab99

Browse files
author
Guanheng Zhang
committed
BertEmbedding to accept one input tuple in forward func
1 parent 8bcffe6 commit beeab99

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

examples/BERT/model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def __init__(self, ntoken, ninp, dropout=0.5):
4343
self.norm = LayerNorm(ninp)
4444
self.dropout = Dropout(dropout)
4545

46-
def forward(self, src, token_type_input):
46+
def forward(self, seq_inputs):
47+
src, token_type_input = seq_inputs
4748
src = self.embed(src) + self.pos_embed(src) \
4849
+ self.tok_type_embed(src, token_type_input)
4950
return self.dropout(self.norm(src))
@@ -107,8 +108,8 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5)
107108
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
108109
self.ninp = ninp
109110

110-
def forward(self, src, token_type_input):
111-
src = self.bert_embed(src, token_type_input)
111+
def forward(self, seq_inputs):
112+
src = self.bert_embed(seq_inputs)
112113
output = self.transformer_encoder(src)
113114
return output
114115

@@ -127,7 +128,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
127128

128129
def forward(self, src, token_type_input=None):
129130
src = src.transpose(0, 1) # Wrap up by nn.DataParallel
130-
output = self.bert_model(src, token_type_input)
131+
output = self.bert_model((src, token_type_input))
131132
output = self.mlm_span(output)
132133
output = self.activation(output)
133134
output = self.norm_layer(output)
@@ -148,7 +149,7 @@ def __init__(self, bert_model):
148149

149150
def forward(self, src, token_type_input):
150151
src = src.transpose(0, 1) # Wrap up by nn.DataParallel
151-
output = self.bert_model(src, token_type_input)
152+
output = self.bert_model((src, token_type_input))
152153
# Send the first <'cls'> seq to a classifier
153154
output = self.activation(self.linear_layer(output[0]))
154155
output = self.ns_span(output)
@@ -165,7 +166,7 @@ def __init__(self, bert_model):
165166
self.qa_span = Linear(bert_model.ninp, 2)
166167

167168
def forward(self, src, token_type_input):
168-
output = self.bert_model(src, token_type_input)
169+
output = self.bert_model((src, token_type_input))
169170
# transpose output (S, N, E) to (N, S, E)
170171
output = output.transpose(0, 1)
171172
output = self.activation(output)

0 commit comments

Comments (0)