
Commit 84d346b

Merge pull request #1195 from huggingface/reorder_arguments
[2.0] Reordering arguments for torch jit #1010 and future TF2.0 compatibility
2 parents 995e38b + 3f05de6 commit 84d346b

11 files changed: +337 / -255 lines
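Why the order matters: torch.jit.trace binds example inputs to forward() positionally, so the arguments a BERT model is usually traced with (input_ids, then attention_mask) need to sit in the leading parameter slots. Below is a minimal sketch of the intended usage after this change; it is not taken from the commit, and the checkpoint name and example sentence are illustrative placeholders.

import torch
from pytorch_transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# torchscript=True makes the model trace-friendly (output embeddings are cloned rather than tied)
model = BertModel.from_pretrained('bert-base-uncased', torchscript=True)
model.eval()

input_ids = torch.tensor([tokenizer.encode("Hello, how are you?")])
attention_mask = torch.ones_like(input_ids)

# With forward(input_ids, attention_mask=None, token_type_ids=None, ...) the two most
# common inputs can be passed positionally, which is exactly what tracing expects.
traced_model = torch.jit.trace(model, (input_ids, attention_mask))
outputs = traced_model(input_ids, attention_mask)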

pytorch_transformers/modeling_bert.py

Lines changed: 74 additions & 77 deletions
@@ -509,18 +509,18 @@ def _init_weights(self, module):
             Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
             See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
             :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
-        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
-            Indices of positions of each input sequence tokens in the position embeddings.
-            Selected in the range ``[0, config.max_position_embeddings - 1]``.
+        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
+            Mask to avoid performing attention on padding token indices.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
         **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Segment token indices to indicate first and second portions of the inputs.
             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
             corresponds to a `sentence B` token
             (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
-        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
-            Mask to avoid performing attention on padding token indices.
-            Mask values selected in ``[0, 1]``:
-            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Indices of positions of each input sequence tokens in the position embeddings.
+            Selected in the range ``[0, config.max_position_embeddings - 1]``.
         **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
@@ -581,7 +581,7 @@ def _prune_heads(self, heads_to_prune):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
@@ -684,10 +684,14 @@ def tie_weights(self):
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
-                next_sentence_label=None, position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                masked_lm_labels=None, next_sentence_label=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)

         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@@ -752,10 +756,14 @@ def tie_weights(self):
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
-                position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                masked_lm_labels=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)

         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
@@ -809,10 +817,15 @@ def __init__(self, config):

         self.init_weights()

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
-                position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                next_sentence_label=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
+
         pooled_output = outputs[1]

         seq_relationship_score = self.cls(pooled_output)
@@ -870,10 +883,15 @@ def __init__(self, config):

         self.init_weights()

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
-                position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, labels=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
+
         pooled_output = outputs[1]

         pooled_output = self.dropout(pooled_output)
@@ -896,45 +914,9 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=No

 @add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
     the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
-    BERT_START_DOCSTRING)
+    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
 class BertForMultipleChoice(BertPreTrainedModel):
     r"""
-    Inputs:
-        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
-            Indices of input sequence tokens in the vocabulary.
-            The second dimension of the input (`num_choices`) indicates the number of choices to score.
-            To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
-
-            (a) For sequence pairs:
-
-                ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
-
-                ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
-
-            (b) For single sequences:
-
-                ``tokens: [CLS] the dog is hairy . [SEP]``
-
-                ``token_type_ids: 0 0 0 0 0 0 0``
-
-            Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
-            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
-            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
-        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
-            Segment token indices to indicate first and second portions of the inputs.
-            The second dimension of the input (`num_choices`) indicates the number of choices to score.
-            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
-            corresponds to a `sentence B` token
-            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
-        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
-            Mask to avoid performing attention on padding token indices.
-            The second dimension of the input (`num_choices`) indicates the number of choices to score.
-            Mask values selected in ``[0, 1]``:
-            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
-        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
-            Mask to nullify selected heads of the self-attention modules.
-            Mask values selected in ``[0, 1]``:
-            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
         **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
             Labels for computing the multiple choice classification loss.
             Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
@@ -974,16 +956,21 @@ def __init__(self, config):

         self.init_weights()

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
-                position_ids=None, head_mask=None):
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, labels=None):
         num_choices = input_ids.shape[1]

-        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
-        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
-        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
-        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
-                            attention_mask=flat_attention_mask, head_mask=head_mask)
+        input_ids = input_ids.view(-1, input_ids.size(-1))
+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
+
         pooled_output = outputs[1]

         pooled_output = self.dropout(pooled_output)
@@ -1042,10 +1029,15 @@ def __init__(self, config):

         self.init_weights()

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
-                position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, labels=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
+
         sequence_output = outputs[0]

         sequence_output = self.dropout(sequence_output)
@@ -1116,10 +1108,15 @@ def __init__(self, config):

         self.init_weights()

-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
-                end_positions=None, position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                start_positions=None, end_positions=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
+
         sequence_output = outputs[0]

         logits = self.qa_outputs(sequence_output)
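The reorder is a breaking change only for call sites that passed the optional BERT arguments positionally; keyword calls are unaffected. A small, self-contained illustration of the difference after this commit (the checkpoint name and example text are placeholders, and the classification head is freshly initialised):

import torch
from pytorch_transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model.eval()

input_ids = torch.tensor([tokenizer.encode("is this jacksonville ? no it is not .")])
attention_mask = torch.ones_like(input_ids)
token_type_ids = torch.zeros_like(input_ids)

# Keyword call: identical before and after this commit.
logits = model(input_ids,
               attention_mask=attention_mask,
               token_type_ids=token_type_ids)[0]

# Positional call: the second slot is now attention_mask (it used to be token_type_ids),
# so old positional call sites like model(input_ids, token_type_ids) must be updated.
logits = model(input_ids, attention_mask, token_type_ids)[0]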

pytorch_transformers/modeling_distilbert.py

Lines changed: 9 additions & 9 deletions
@@ -524,10 +524,10 @@ def tie_weights(self):
         self._tie_or_clone_weights(self.vocab_projector,
                                    self.distilbert.embeddings.word_embeddings)

-    def forward(self, input_ids, attention_mask=None, masked_lm_labels=None, head_mask=None):
+    def forward(self, input_ids, attention_mask=None, head_mask=None, masked_lm_labels=None):
         dlbrt_output = self.distilbert(input_ids=input_ids,
-                                       attention_mask=attention_mask,
-                                       head_mask=head_mask)
+                                       attention_mask=attention_mask,
+                                       head_mask=head_mask)
         hidden_states = dlbrt_output[0]  # (bs, seq_length, dim)
         prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
         prediction_logits = gelu(prediction_logits)  # (bs, seq_length, dim)
@@ -588,10 +588,10 @@ def __init__(self, config):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, labels=None, head_mask=None):
+    def forward(self, input_ids, attention_mask=None, head_mask=None, labels=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
-                                            attention_mask=attention_mask,
-                                            head_mask=head_mask)
+                                            attention_mask=attention_mask,
+                                            head_mask=head_mask)
         hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
         pooled_output = hidden_state[:, 0]  # (bs, dim)
         pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
@@ -662,10 +662,10 @@ def __init__(self, config):

         self.init_weights()

-    def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None, head_mask=None):
+    def forward(self, input_ids, attention_mask=None, head_mask=None, start_positions=None, end_positions=None):
         distilbert_output = self.distilbert(input_ids=input_ids,
-                                            attention_mask=attention_mask,
-                                            head_mask=head_mask)
+                                            attention_mask=attention_mask,
+                                            head_mask=head_mask)
         hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)

         hidden_states = self.dropout(hidden_states)  # (bs, max_query_len, dim)
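The DistilBERT heads follow the same convention: tensor inputs first, label-style arguments (labels, start_positions, end_positions) last, so a traced module only needs the leading positional arguments while training code passes labels by keyword. A hedged sketch under that assumption; the checkpoint name, text, and label are placeholders, and the classification head is freshly initialised:

import torch
from pytorch_transformers import DistilBertForSequenceClassification, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
model.eval()

input_ids = torch.tensor([tokenizer.encode("a quietly effective little film")])
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([1])

# Tensor inputs positionally, labels by keyword; with labels given, the tuple starts (loss, logits, ...).
loss, logits = model(input_ids, attention_mask, labels=labels)[:2]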
