This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit c34c150

datumbox authored and facebook-github-bot committed
Pass an embedding layer to the constructor of the BertModel class (#1135)
Reviewed By: zhangguanheng66

Differential Revision: D26369001

fbshipit-source-id: f5a67a2a812d568073505ec4d181f6e418eb4a3f
1 parent 125684c commit c34c150

File tree

3 files changed: +17 -13 lines changed

examples/BERT/model.py

Lines changed: 11 additions & 9 deletions

@@ -43,7 +43,8 @@ def __init__(self, ntoken, ninp, dropout=0.5):
         self.norm = LayerNorm(ninp)
         self.dropout = Dropout(dropout)

-    def forward(self, src, token_type_input):
+    def forward(self, seq_inputs):
+        src, token_type_input = seq_inputs
         src = self.embed(src) + self.pos_embed(src) \
             + self.tok_type_embed(src, token_type_input)
         return self.dropout(self.norm(src))

@@ -99,16 +100,16 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None):
 class BertModel(nn.Module):
     """Contain a transformer encoder."""

-    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
+    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5):
         super(BertModel, self).__init__()
         self.model_type = 'Transformer'
-        self.bert_embed = BertEmbedding(ntoken, ninp)
+        self.bert_embed = embed_layer
         encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
         self.ninp = ninp

-    def forward(self, src, token_type_input):
-        src = self.bert_embed(src, token_type_input)
+    def forward(self, seq_inputs):
+        src = self.bert_embed(seq_inputs)
         output = self.transformer_encoder(src)
         return output

@@ -118,15 +119,16 @@ class MLMTask(nn.Module):

     def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
         super(MLMTask, self).__init__()
-        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, dropout=0.5)
+        embed_layer = BertEmbedding(ntoken, ninp)
+        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5)
         self.mlm_span = Linear(ninp, ninp)
         self.activation = F.gelu
         self.norm_layer = LayerNorm(ninp, eps=1e-12)
         self.mlm_head = Linear(ninp, ntoken)

     def forward(self, src, token_type_input=None):
         src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         output = self.mlm_span(output)
         output = self.activation(output)
         output = self.norm_layer(output)

@@ -147,7 +149,7 @@ def __init__(self, bert_model):

     def forward(self, src, token_type_input):
         src = src.transpose(0, 1)  # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # Send the first <'cls'> seq to a classifier
         output = self.activation(self.linear_layer(output[0]))
         output = self.ns_span(output)

@@ -164,7 +166,7 @@ def __init__(self, bert_model):
         self.qa_span = Linear(bert_model.ninp, 2)

     def forward(self, src, token_type_input):
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # transpose output (S, N, E) to (N, S, E)
         output = output.transpose(0, 1)
         output = self.activation(output)
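After this change, the embedding layer is built by the caller and BertModel.forward consumes a single (src, token_type_input) tuple. A minimal sketch of the new wiring, assuming the classes from examples/BERT/model.py and illustrative hyperparameter values (the real scripts read them from args):

import torch
from model import BertModel, BertEmbedding

ntoken, ninp = 1000, 256                      # illustrative vocabulary and embedding sizes

# The embedding layer is now constructed outside the model and injected.
embed_layer = BertEmbedding(ntoken, ninp)
bert = BertModel(ntoken, ninp, nhead=8, nhid=1024, nlayers=2,
                 embed_layer=embed_layer, dropout=0.5)

src = torch.randint(0, ntoken, (35, 4))       # (S, N) batch of token ids
token_type = torch.zeros_like(src)            # single-segment token types
output = bert((src, token_type))              # forward now takes one tuple argument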

examples/BERT/ns_task.py

Lines changed: 3 additions & 2 deletions

@@ -5,7 +5,7 @@
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
-from model import NextSentenceTask, BertModel
+from model import NextSentenceTask, BertModel, BertEmbedding
 from utils import run_demo, run_ddp, wrap_up

@@ -149,7 +149,8 @@ def run_main(args, rank=None):
     if args.checkpoint != 'None':
         model = torch.load(args.checkpoint)
     else:
-        pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+        embed_layer = BertEmbedding(len(vocab), args.emsize)
+        pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, embed_layer, args.dropout)
         pretrained_bert.load_state_dict(torch.load(args.bert_model))
         model = NextSentenceTask(pretrained_bert)

examples/BERT/qa_task.py

Lines changed: 3 additions & 2 deletions

@@ -9,7 +9,7 @@
 from model import QuestionAnswerTask
 from metrics import compute_qa_exact, compute_qa_f1
 from utils import print_loss_log
-from model import BertModel
+from model import BertModel, BertEmbedding

 def process_raw_data(data):

@@ -174,7 +174,8 @@ def train():
     train_dataset = process_raw_data(train_dataset)
     dev_dataset = process_raw_data(dev_dataset)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+    embed_layer = BertEmbedding(len(vocab), args.emsize)
+    pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, embed_layer, args.dropout)
     pretrained_bert.load_state_dict(torch.load(args.bert_model))
     model = QuestionAnswerTask(pretrained_bert).to(device)
     criterion = nn.CrossEntropyLoss()
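Note that the task heads (MLMTask, NextSentenceTask, QuestionAnswerTask) keep their two-argument forward and pack the pair into a tuple internally, so the training scripts call them exactly as before. A minimal sketch with illustrative sizes and a freshly initialized model in place of the args.bert_model checkpoint:

import torch
from model import BertModel, BertEmbedding, QuestionAnswerTask

vocab_size, emsize = 1000, 256                 # stand-ins for len(vocab) and args.emsize
embed_layer = BertEmbedding(vocab_size, emsize)
bert = BertModel(vocab_size, emsize, 8, 1024, 2, embed_layer, 0.2)
model = QuestionAnswerTask(bert)

src = torch.randint(0, vocab_size, (128, 4))   # (S, N) question+context token ids
token_type = torch.zeros_like(src)
out = model(src, token_type)                   # unchanged two-argument call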
