@@ -509,18 +509,18 @@ def _init_weights(self, module):
             Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
             See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
             :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
-        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
-            Indices of positions of each input sequence tokens in the position embeddings.
-            Selected in the range ``[0, config.max_position_embeddings - 1]``.
+        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
+            Mask to avoid performing attention on padding token indices.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
         **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
             Segment token indices to indicate first and second portions of the inputs.
             Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
             corresponds to a `sentence B` token
             (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
-        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
-            Mask to avoid performing attention on padding token indices.
-            Mask values selected in ``[0, 1]``:
-            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
+            Indices of positions of each input sequence tokens in the position embeddings.
+            Selected in the range ``[0, config.max_position_embeddings - 1]``.
         **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
@@ -581,7 +581,7 @@ def _prune_heads(self, heads_to_prune):
         for layer, heads in heads_to_prune.items():
             self.encoder.layer[layer].attention.prune_heads(heads)
 
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
         if attention_mask is None:
             attention_mask = torch.ones_like(input_ids)
         if token_type_ids is None:
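
The effect of this reordering on callers is easiest to see in a short usage sketch (a minimal example; the checkpoint name and sentence are illustrative, not part of this diff). `attention_mask` is now the second positional argument of `BertModel.forward`, so purely positional calls that used to pass `token_type_ids` second need updating; keyword arguments are unaffected and remain the safer style:

```python
# Minimal usage sketch for the new argument order (illustrative checkpoint/sentence).
import torch
from pytorch_transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

input_ids = torch.tensor([tokenizer.encode("Hello, my dog is cute")])  # (1, sequence_length)
attention_mask = torch.ones_like(input_ids)  # 1 = attend to token, 0 = padding

# Positionally, attention_mask now comes right after input_ids;
# keyword calls like this one are unaffected by the reordering.
outputs = model(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = outputs[:2]
```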
@@ -684,10 +684,14 @@ def tie_weights(self):
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)
 
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
-                next_sentence_label=None, position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                masked_lm_labels=None, next_sentence_label=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
 
         sequence_output, pooled_output = outputs[:2]
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
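
The same pattern holds for the heads that take training labels: the label arguments (`masked_lm_labels`, `next_sentence_label`) now come after all encoder inputs, so passing them by keyword keeps calls robust to the reordering. A rough sketch, assuming both labels are supplied (placeholder checkpoint and toy tensors):

```python
# Hedged sketch: the labels are now the trailing arguments and are passed by keyword.
import torch
from pytorch_transformers import BertForPreTraining

model = BertForPreTraining.from_pretrained('bert-base-uncased')  # illustrative checkpoint
input_ids = torch.randint(low=1, high=1000, size=(1, 8))         # placeholder token ids

outputs = model(input_ids,
                masked_lm_labels=input_ids,             # toy labels, just to produce a loss
                next_sentence_label=torch.tensor([0]))
total_loss, prediction_scores, seq_relationship_score = outputs[:3]
```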
@@ -752,10 +756,14 @@ def tie_weights(self):
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)
 
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
-                position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                masked_lm_labels=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
 
         sequence_output = outputs[0]
         prediction_scores = self.cls(sequence_output)
@@ -809,10 +817,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None,
-                position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                next_sentence_label=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
+
         pooled_output = outputs[1]
 
         seq_relationship_score = self.cls(pooled_output)
@@ -870,10 +883,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
-                position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, labels=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
+
         pooled_output = outputs[1]
 
         pooled_output = self.dropout(pooled_output)
@@ -896,45 +914,9 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=No
 
 @add_start_docstrings("""Bert Model with a multiple choice classification head on top (a linear layer on top of
     the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
-    BERT_START_DOCSTRING)
+    BERT_START_DOCSTRING, BERT_INPUTS_DOCSTRING)
 class BertForMultipleChoice(BertPreTrainedModel):
     r"""
-    Inputs:
-        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
-            Indices of input sequence tokens in the vocabulary.
-            The second dimension of the input (`num_choices`) indicates the number of choices to score.
-            To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
-
-            (a) For sequence pairs:
-
-                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
-
-                ``token_type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1``
-
-            (b) For single sequences:
-
-                ``tokens:         [CLS] the dog is hairy . [SEP]``
-
-                ``token_type_ids:   0   0   0   0  0     0   0``
-
-            Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
-            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
-            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
-        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
-            Segment token indices to indicate first and second portions of the inputs.
-            The second dimension of the input (`num_choices`) indicates the number of choices to score.
-            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
-            corresponds to a `sentence B` token
-            (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
-        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
-            Mask to avoid performing attention on padding token indices.
-            The second dimension of the input (`num_choices`) indicates the number of choices to score.
-            Mask values selected in ``[0, 1]``:
-            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
-        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
-            Mask to nullify selected heads of the self-attention modules.
-            Mask values selected in ``[0, 1]``:
-            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
         **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
             Labels for computing the multiple choice classification loss.
             Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
@@ -974,16 +956,21 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
-                position_ids=None, head_mask=None):
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, labels=None):
         num_choices = input_ids.shape[1]
 
-        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
-        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
-        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
-        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
-        outputs = self.bert(flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids,
-                            attention_mask=flat_attention_mask, head_mask=head_mask)
+        input_ids = input_ids.view(-1, input_ids.size(-1))
+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
+
         pooled_output = outputs[1]
 
         pooled_output = self.dropout(pooled_output)
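
For the multiple-choice head, inputs keep an extra `num_choices` dimension, and the forward pass above flattens it before the shared BERT encoder, reshaping the logits back afterwards. A small illustrative sketch (placeholder checkpoint and random ids) of the expected shapes:

```python
# Illustrative shapes for BertForMultipleChoice: inputs are rank-3,
# (batch_size, num_choices, sequence_length); scores come back as (batch_size, num_choices).
import torch
from pytorch_transformers import BertForMultipleChoice

model = BertForMultipleChoice.from_pretrained('bert-base-uncased')  # illustrative checkpoint

batch_size, num_choices, seq_len = 2, 4, 8
input_ids = torch.randint(low=1, high=1000, size=(batch_size, num_choices, seq_len))
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([1, 3])  # index of the correct choice for each example

outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss, classification_scores = outputs[:2]  # classification_scores: (batch_size, num_choices)
```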
@@ -1042,10 +1029,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
-                position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
+                position_ids=None, head_mask=None, labels=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
+
         sequence_output = outputs[0]
 
         sequence_output = self.dropout(sequence_output)
@@ -1116,10 +1108,15 @@ def __init__(self, config):
 
         self.init_weights()
 
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
-                end_positions=None, position_ids=None, head_mask=None):
-        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-                            attention_mask=attention_mask, head_mask=head_mask)
+    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                start_positions=None, end_positions=None):
+
+        outputs = self.bert(input_ids,
+                            attention_mask=attention_mask,
+                            token_type_ids=token_type_ids,
+                            position_ids=position_ids,
+                            head_mask=head_mask)
+
         sequence_output = outputs[0]
 
         logits = self.qa_outputs(sequence_output)
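
The question-answering head follows the same convention: the supervision arguments (`start_positions`, `end_positions`) now trail the encoder inputs and are best passed by keyword. A hedged sketch with placeholder ids and span indices:

```python
# Hedged sketch for BertForQuestionAnswering after the reordering (placeholder data).
import torch
from pytorch_transformers import BertForQuestionAnswering

model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')  # illustrative checkpoint
input_ids = torch.randint(low=1, high=1000, size=(1, 12))              # placeholder token ids

start_positions = torch.tensor([3])  # toy answer-span start index
end_positions = torch.tensor([5])    # toy answer-span end index

outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
total_loss, start_logits, end_logits = outputs[:3]
```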