From a726b93005607bbc8ce1a7ef9c6ad8f5097c9dcc Mon Sep 17 00:00:00 2001
From: Guanheng Zhang
Date: Thu, 28 Jan 2021 08:55:37 -0800
Subject: [PATCH 1/3] Pass embed layer to BERTModel

---
 examples/BERT/model.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/BERT/model.py b/examples/BERT/model.py
index 316841c751..20f15b4501 100644
--- a/examples/BERT/model.py
+++ b/examples/BERT/model.py
@@ -99,10 +99,10 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None):
 class BertModel(nn.Module):
     """Contain a transformer encoder."""
 
-    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
+    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5):
         super(BertModel, self).__init__()
         self.model_type = 'Transformer'
-        self.bert_embed = BertEmbedding(ntoken, ninp)
+        self.bert_embed = embed_layer
         encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
         self.ninp = ninp
@@ -118,7 +118,8 @@ class MLMTask(nn.Module):
 
     def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
         super(MLMTask, self).__init__()
-        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, dropout=0.5)
+        embed_layer = BertEmbedding(ntoken, ninp)
+        self.bert_model = BertModel(ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5)
         self.mlm_span = Linear(ninp, ninp)
         self.activation = F.gelu
         self.norm_layer = LayerNorm(ninp, eps=1e-12)

From 8bcffe6d023dc8706d1b8c37d8f67e3e7c3b3ae6 Mon Sep 17 00:00:00 2001
From: Guanheng Zhang
Date: Thu, 28 Jan 2021 09:52:32 -0800
Subject: [PATCH 2/3] update ns and qa task

---
 examples/BERT/ns_task.py | 5 +++--
 examples/BERT/qa_task.py | 5 +++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/BERT/ns_task.py b/examples/BERT/ns_task.py
index 06786a82dc..3084686ebb 100644
--- a/examples/BERT/ns_task.py
+++ b/examples/BERT/ns_task.py
@@ -5,7 +5,7 @@
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data import DataLoader
-from model import NextSentenceTask, BertModel
+from model import NextSentenceTask, BertModel, BertEmbedding
 from utils import run_demo, run_ddp, wrap_up
 
 
@@ -149,7 +149,8 @@ def run_main(args, rank=None):
     if args.checkpoint != 'None':
         model = torch.load(args.checkpoint)
     else:
-        pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+        embed_layer = BertEmbedding(len(vocab), args.emsize)
+        pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, embed_layer, args.dropout)
         pretrained_bert.load_state_dict(torch.load(args.bert_model))
         model = NextSentenceTask(pretrained_bert)
 
diff --git a/examples/BERT/qa_task.py b/examples/BERT/qa_task.py
index 72595c101d..b2239bc612 100644
--- a/examples/BERT/qa_task.py
+++ b/examples/BERT/qa_task.py
@@ -9,7 +9,7 @@
 from model import QuestionAnswerTask
 from metrics import compute_qa_exact, compute_qa_f1
 from utils import print_loss_log
-from model import BertModel
+from model import BertModel, BertEmbedding
 
 
 def process_raw_data(data):
@@ -174,7 +174,8 @@ def train():
     train_dataset = process_raw_data(train_dataset)
     dev_dataset = process_raw_data(dev_dataset)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)
+    embed_layer = BertEmbedding(len(vocab), args.emsize)
+    pretrained_bert = BertModel(len(vocab), args.emsize, args.nhead, args.nhid, args.nlayers, embed_layer, args.dropout)
     pretrained_bert.load_state_dict(torch.load(args.bert_model))
     model = QuestionAnswerTask(pretrained_bert).to(device)
     criterion = nn.CrossEntropyLoss()

From beeab99aa967f8e700e9027a200aa3ebe83661b2 Mon Sep 17 00:00:00 2001
From: Guanheng Zhang
Date: Thu, 28 Jan 2021 11:24:22 -0800
Subject: [PATCH 3/3] BertEmbedding to accept one input tuple in forward func

---
 examples/BERT/model.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/examples/BERT/model.py b/examples/BERT/model.py
index 20f15b4501..484117e19c 100644
--- a/examples/BERT/model.py
+++ b/examples/BERT/model.py
@@ -43,7 +43,8 @@ def __init__(self, ntoken, ninp, dropout=0.5):
         self.norm = LayerNorm(ninp)
         self.dropout = Dropout(dropout)
 
-    def forward(self, src, token_type_input):
+    def forward(self, seq_inputs):
+        src, token_type_input = seq_inputs
         src = self.embed(src) + self.pos_embed(src) \
             + self.tok_type_embed(src, token_type_input)
         return self.dropout(self.norm(src))
@@ -107,8 +108,8 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, embed_layer, dropout=0.5)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
         self.ninp = ninp
 
-    def forward(self, src, token_type_input):
-        src = self.bert_embed(src, token_type_input)
+    def forward(self, seq_inputs):
+        src = self.bert_embed(seq_inputs)
         output = self.transformer_encoder(src)
         return output
 
@@ -127,7 +128,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
 
     def forward(self, src, token_type_input=None):
         src = src.transpose(0, 1) # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         output = self.mlm_span(output)
         output = self.activation(output)
         output = self.norm_layer(output)
@@ -148,7 +149,7 @@ def __init__(self, bert_model):
 
     def forward(self, src, token_type_input):
         src = src.transpose(0, 1) # Wrap up by nn.DataParallel
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
         # Send the first <'cls'> seq to a classifier
         output = self.activation(self.linear_layer(output[0]))
         output = self.ns_span(output)
@@ -165,7 +166,7 @@ def __init__(self, bert_model):
         self.qa_span = Linear(bert_model.ninp, 2)
 
     def forward(self, src, token_type_input):
-        output = self.bert_model(src, token_type_input)
+        output = self.bert_model((src, token_type_input))
        # transpose output (S, N, E) to (N, S, E)
         output = output.transpose(0, 1)
         output = self.activation(output)
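
---

Taken together, the series changes the example's calling convention in two
ways: BertModel now receives a pre-built embedding layer instead of
constructing its own BertEmbedding, and BertModel/BertEmbedding forward now
takes a single (src, token_type_input) tuple rather than two arguments.
Below is a minimal sketch of the resulting usage against the patched
model.py; the hyperparameter values are toy choices for illustration only
(they are not prescribed by the patches), and the (S, N) input shape follows
how MLMTask and QuestionAnswerTask feed BertModel:

    import torch
    from model import BertModel, BertEmbedding

    # Toy sizes: vocab, embedding dim, heads, FFN dim, encoder layers.
    ntoken, emsize, nhead, nhid, nlayers = 1000, 256, 4, 1024, 2

    # Build the embedding layer externally and inject it (patch 1).
    embed_layer = BertEmbedding(ntoken, emsize)
    bert = BertModel(ntoken, emsize, nhead, nhid, nlayers, embed_layer, dropout=0.1)

    seq_len, batch_size = 35, 8
    src = torch.randint(0, ntoken, (seq_len, batch_size))             # (S, N) token ids
    token_type_input = torch.zeros((seq_len, batch_size), dtype=torch.long)

    # Forward now takes one tuple argument (patch 3).
    output = bert((src, token_type_input))                            # (S, N, E)

The ns_task.py and qa_task.py scripts follow the same pattern after patch 2:
build BertEmbedding(len(vocab), args.emsize), pass it to BertModel, then
restore the pretrained weights with load_state_dict before wrapping the model
in the fine-tuning task head.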