151 changes: 74 additions & 77 deletions pytorch_transformers/modeling_bert.py
@@ -509,18 +509,18 @@ def _init_weights(self, module):
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence token in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
(see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence token in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
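
To ground the reordered docstring entries above, here is a minimal usage sketch for ``BertModel`` under the new keyword order. The ``bert-base-uncased`` checkpoint, the sentences, and the padding scheme are illustrative assumptions, not part of the diff; the snippet assumes ``pytorch_transformers`` is installed with that checkpoint available.

```python
import torch
from pytorch_transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')   # illustrative checkpoint
model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

# Two sequences padded to the same length; 0 is the pad id in this vocabulary.
ids_a = tokenizer.encode("Hello, my dog is cute")
ids_b = tokenizer.encode("Short sentence")
max_len = max(len(ids_a), len(ids_b))
batch = [ids + [0] * (max_len - len(ids)) for ids in (ids_a, ids_b)]

input_ids = torch.tensor(batch)                    # (batch_size, sequence_length)
attention_mask = (input_ids != 0).long()           # 1 = real token (not masked), 0 = padding
token_type_ids = torch.zeros_like(input_ids)       # single-segment input: all "sentence A"

with torch.no_grad():
    sequence_output, pooled_output = model(input_ids,
                                           attention_mask=attention_mask,
                                           token_type_ids=token_type_ids)[:2]
```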
@@ -581,7 +581,7 @@ def _prune_heads(self, heads_to_prune):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
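
Because ``attention_mask`` and ``token_type_ids`` swap positions in this signature, callers written against the old positional order would now bind each tensor to the wrong parameter, which is why the sketch above passes everything after ``input_ids`` by keyword. The lines just above also show the defaults: an omitted ``attention_mask`` becomes all ones and an omitted ``token_type_ids`` all zeros, so for a single unpadded segment both keywords can simply be dropped. A tiny equivalence sketch (it reuses ``tokenizer`` and the eval-mode ``model`` from the previous example):

```python
# With no padding and a single segment, omitting the masks is equivalent to passing
# torch.ones_like(input_ids) and torch.zeros_like(input_ids) explicitly, per the
# forward() defaults shown above.
single = torch.tensor([tokenizer.encode("Hello, my dog is cute")])
with torch.no_grad():
    out_default = model(single)[0]
    out_explicit = model(single,
                         attention_mask=torch.ones_like(single),
                         token_type_ids=torch.zeros_like(single))[0]
assert torch.allclose(out_default, out_explicit)
```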
@@ -684,10 +684,14 @@ def tie_weights(self):
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.bert.embeddings.word_embeddings)

def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
next_sentence_label=None, position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None, next_sentence_label=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)

sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
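
A hedged calling sketch for ``BertForPreTraining`` with both label arguments now at the end of the signature; the token ids and label tensors are placeholders, and the loss-first tuple layout follows the code shown here.

```python
import torch
from pytorch_transformers import BertForPreTraining

model = BertForPreTraining.from_pretrained('bert-base-uncased')   # illustrative checkpoint

input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])  # placeholder ids
masked_lm_labels = input_ids.clone()       # placeholder: score every position against itself
next_sentence_label = torch.tensor([0])    # 0 = sentence B really follows sentence A

outputs = model(input_ids,
                masked_lm_labels=masked_lm_labels,
                next_sentence_label=next_sentence_label)
total_loss, prediction_scores, seq_relationship_score = outputs[:3]
```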
@@ -752,10 +756,14 @@ def tie_weights(self):
self._tie_or_clone_weights(self.cls.predictions.decoder,
self.bert.embeddings.word_embeddings)

def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
masked_lm_labels=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)

sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
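
The masked-LM head follows the same pattern, with ``masked_lm_labels`` now the last parameter; the ids and mask position below are illustrative.

```python
import torch
from pytorch_transformers import BertForMaskedLM

model = BertForMaskedLM.from_pretrained('bert-base-uncased')      # illustrative checkpoint

input_ids = torch.tensor([[101, 7592, 103, 2003, 10140, 102]])    # 103 = [MASK] (illustrative)
outputs = model(input_ids,
                attention_mask=torch.ones_like(input_ids),
                masked_lm_labels=input_ids)                        # placeholder labels
masked_lm_loss, prediction_scores = outputs[:2]
```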
@@ -809,10 +817,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
next_sentence_label=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)

pooled_output = outputs[1]

seq_relationship_score = self.cls(pooled_output)
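
For ``BertForNextSentencePrediction`` the two segments are separated via ``token_type_ids``; the ids, segment split, and label below are illustrative.

```python
import torch
from pytorch_transformers import BertForNextSentencePrediction

model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')  # illustrative checkpoint

# [CLS] sentence A [SEP] sentence B [SEP] -- ids and segment boundaries are placeholders
input_ids = torch.tensor([[101, 2003, 2023, 3958, 102, 2053, 2009, 2003, 2025, 102]])
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]])

outputs = model(input_ids,
                token_type_ids=token_type_ids,
                next_sentence_label=torch.tensor([0]))   # 0 = B follows A
loss, seq_relationship_score = outputs[:2]
```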
@@ -870,10 +883,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
Reviewer comment: This seems to solve #1010


outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)

pooled_output = outputs[1]

pooled_output = self.dropout(pooled_output)
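
The reordering here, with ``labels`` moved behind the mask arguments, is what the review comment above ties to #1010. A hedged usage sketch with placeholder ids and the default two-class head:

```python
import torch
from pytorch_transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased')  # num_labels defaults to 2

input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])  # placeholder ids
outputs = model(input_ids,
                attention_mask=torch.ones_like(input_ids),
                labels=torch.tensor([1]))          # one class index per example
loss, logits = outputs[:2]
```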
@@ -896,45 +914,9 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=No

@add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
BERT_START_DOCSTRING)
BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
class BertForMultipleChoice(BertPreTrainedModel):
r"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:

(a) For sequence pairs:

``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``

(b) For single sequences:

``tokens: [CLS] the dog is hairy . [SEP]``

``token_type_ids: 0 0 0 0 0 0 0``

Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
(see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Mask to avoid performing attention on padding token indices.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
@@ -974,16 +956,21 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):
num_choices = input_ids.shape[1]

flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
attention_mask=flat_attention_mask, head_mask=head_mask)
input_ids = input_ids.view(-1, input_ids.size(-1))
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)

pooled_output = outputs[1]

pooled_output = self.dropout(pooled_output)
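
Because the forward pass above flattens the choice dimension itself, callers just stack one encoded sequence per choice into a ``(batch_size, num_choices, sequence_length)`` tensor; everything below is illustrative.

```python
import torch
from pytorch_transformers import BertForMultipleChoice

model = BertForMultipleChoice.from_pretrained('bert-base-uncased')   # illustrative checkpoint

# One example with two candidate endings: (batch_size=1, num_choices=2, sequence_length=6).
input_ids = torch.tensor([[[101, 2023, 2003, 2028, 102, 0],
                           [101, 2023, 2003, 2048, 102, 0]]])        # placeholder ids, 0 = pad
attention_mask = (input_ids != 0).long()                             # mask out the padding

outputs = model(input_ids,
                attention_mask=attention_mask,
                labels=torch.tensor([0]))          # index of the correct choice
loss, classification_scores = outputs[:2]          # scores: (batch_size, num_choices)
```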
@@ -1042,10 +1029,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
def forward(self, input_ids, attention_mask=None, token_type_ids=None,
position_ids=None, head_mask=None, labels=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)

sequence_output = outputs[0]

sequence_output = self.dropout(sequence_output)
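
For the token-level head, ``labels`` (again last in the new order) carries one class index per token; the tag ids below are placeholders for the default two-label scheme.

```python
import torch
from pytorch_transformers import BertForTokenClassification

model = BertForTokenClassification.from_pretrained('bert-base-uncased')   # num_labels defaults to 2

input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 102]])   # placeholder ids
labels = torch.tensor([[0, 1, 0, 0, 1, 0]])                      # one label per token (placeholder)

outputs = model(input_ids,
                attention_mask=torch.ones_like(input_ids),
                labels=labels)
loss, logits = outputs[:2]        # logits: (batch_size, sequence_length, num_labels)
```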
@@ -1116,10 +1108,15 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
end_positions=None, position_ids=None, head_mask=None):
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
start_positions=None, end_positions=None):

outputs = self.bert(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)

sequence_output = outputs[0]

logits = self.qa_outputs(sequence_output)
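
The QA head takes the answer span as ``start_positions``/``end_positions``, which now trail the mask arguments; the ids, segment split, and span indices below are placeholders.

```python
import torch
from pytorch_transformers import BertForQuestionAnswering

model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')   # illustrative checkpoint

input_ids = torch.tensor([[101, 2040, 2001, 3958, 102, 3958, 2001, 1037, 3836, 102]])  # placeholder ids
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]])   # 0 = question, 1 = context

outputs = model(input_ids,
                token_type_ids=token_type_ids,
                start_positions=torch.tensor([5]),   # placeholder span start
                end_positions=torch.tensor([8]))     # placeholder span end
total_loss, start_logits, end_logits = outputs[:3]
```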
18 changes: 9 additions & 9 deletions pytorch_transformers/modeling_distilbert.py
@@ -524,10 +524,10 @@ def tie_weights(self):
self._tie_or_clone_weights(self.vocab_projector,
self.distilbert.embeddings.word_embeddings)

def forward(self, input_ids, attention_mask=None, masked_lm_labels=None, head_mask=None):
def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
dlbrt_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
attention_mask=attention_mask,
head_mask=head_mask)
hidden_states = dlbrt_output[0] # (bs, seq_length, dim)
prediction_logits = self.vocab_transform(hidden_states) # (bs, seq_length, dim)
prediction_logits = gelu(prediction_logits) # (bs, seq_length, dim)
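
DistilBERT takes no ``token_type_ids`` or ``position_ids``, so here only ``masked_lm_labels`` moves to the end of the signature; the ``distilbert-base-uncased`` checkpoint and ids are illustrative.

```python
import torch
from pytorch_transformers import DistilBertForMaskedLM

model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')  # illustrative checkpoint

input_ids = torch.tensor([[101, 7592, 103, 2003, 10140, 102]])    # 103 = [MASK] (illustrative)
outputs = model(input_ids,
                attention_mask=torch.ones_like(input_ids),
                masked_lm_labels=input_ids)                        # placeholder labels
mlm_loss, prediction_logits = outputs[:2]
```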
@@ -588,10 +588,10 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
def forward(self, input_ids, attention_mask=None, head_mask=None, labels=None):
distilbert_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
attention_mask=attention_mask,
head_mask=head_mask)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0] # (bs, dim)
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
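
The same pattern holds for the DistilBERT classification head, with ``labels`` last; placeholders throughout.

```python
import torch
from pytorch_transformers import DistilBertForSequenceClassification

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')  # illustrative

input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]])  # placeholder ids
outputs = model(input_ids,
                attention_mask=torch.ones_like(input_ids),
                labels=torch.tensor([1]))          # one class index per example
loss, logits = outputs[:2]
```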
@@ -662,10 +662,10 @@ def __init__(self, config):

self.init_weights()

def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
def forward(self, input_ids, attention_mask=None, head_mask=None, start_positions=None, end_positions=None):
distilbert_output = self.distilbert(input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask)
attention_mask=attention_mask,
head_mask=head_mask)
hidden_states = distilbert_output[0] # (bs, max_query_len, dim)

hidden_states = self.dropout(hidden_states) # (bs, max_query_len, dim)
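
And the DistilBERT QA head mirrors the BERT version minus the segment ids; the span indices below are placeholders.

```python
import torch
from pytorch_transformers import DistilBertForQuestionAnswering

model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')  # illustrative

input_ids = torch.tensor([[101, 2040, 2001, 3958, 102, 3958, 2001, 1037, 3836, 102]])  # placeholder ids
outputs = model(input_ids,
                attention_mask=torch.ones_like(input_ids),
                start_positions=torch.tensor([5]),   # placeholder span start
                end_positions=torch.tensor([8]))     # placeholder span end
total_loss, start_logits, end_logits = outputs[:3]
```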