[2.0] Reordering arguments for torch jit #1010 and future TF2.0 compatibility #1195
Conversation
Force-pushed from 16296da to 7fba47b
Codecov Report
@@ Coverage Diff @@
## master #1195 +/- ##
==========================================
- Coverage 80.83% 80.42% -0.41%
==========================================
Files 46 46
Lines 7878 7892 +14
==========================================
- Hits 6368 6347 -21
- Misses 1510 1545 +35
Continue to review full report at Codecov.
outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
                    attention_mask=attention_mask, head_mask=head_mask)

def forward(self, input_ids, attention_mask=None, token_type_ids=None,
            position_ids=None, head_mask=None, labels=None):
This seems to solve #1010
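For context on why this helps with #1010: `torch.jit.trace` passes its example inputs positionally, so promoting `attention_mask` to the second argument lets the most common inputs be traced without keyword arguments. The sketch below is illustrative only; the import path and the `torchscript=True` flag are assumptions based on the library's API around the time of this PR, not part of the diff.

```python
import torch
from transformers import BertModel  # package was named pytorch_transformers at the time; adjust if needed

# torch.jit.trace only passes example inputs positionally, so the argument
# order of forward() determines which inputs can be traced conveniently.
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

input_ids = torch.randint(0, 30522, (1, 8))   # (batch_size, sequence_length)
attention_mask = torch.ones_like(input_ids)

# With the new order (input_ids, attention_mask, ...), the two usual inputs
# line up with the first two positional arguments.
traced_model = torch.jit.trace(model, (input_ids, attention_mask))
outputs = traced_model(input_ids, attention_mask)
sequence_output, pooled_output = outputs[0], outputs[1]
```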
if attention_mask is not None:
    # Apply the attention mask
    w = w + attention_mask
Nice addition!
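To unpack what `w = w + attention_mask` does, here is a standalone sketch (not the library's exact code) of an additive attention mask: padding positions get a large negative bias so they receive near-zero weight after the softmax. The tensor shapes and the -10000.0 constant are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Raw attention scores: (batch, num_heads, query_len, key_len)
scores = torch.randn(1, 1, 4, 4)

# 1 marks a real token, 0 marks padding
pad_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.float)

# Convert to an additive mask: 0.0 for real tokens, -10000.0 for padding,
# broadcastable over heads and query positions.
additive_mask = (1.0 - pad_mask) * -10000.0
additive_mask = additive_mask[:, None, None, :]

w = scores + additive_mask      # same pattern as `w = w + attention_mask` above
probs = F.softmax(w, dim=-1)    # padded key positions get ~0 probability
```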
if attention_mask is not None:
    # Apply the attention mask
    w = w + attention_mask
Nice addition here too
model.eval()
sequence_output, pooled_output = model(input_ids, token_type_ids, input_mask)
sequence_output, pooled_output = model(input_ids, token_type_ids)
sequence_output, pooled_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
I feel like this is clearer than before! Looks great.
LysandreJik left a comment
This seems to solve the problems related to JIT. I feel it's a great addition, as the order of inputs now accurately reflects their order of importance.
Ok merging
Torch JIT (cf. #1010) and TF 2.0 (cf. #1104) are stricter than PyTorch about the order of arguments, so a natural ordering matters for ease of use.
This PR refactors the order of the keyword arguments to make it as natural as possible.
This is a breaking change for people who rely on positional order to pass keyword arguments to the models' forward pass, so it is delayed until the 2.0 release.
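To make the breaking change concrete, here is a sketch of the before/after calling convention for a BERT-style model, following the test snippet quoted above. The import path and tuple-style return format are assumptions based on the 2.0-era API; callers that pass the optional inputs by keyword are unaffected.

```python
import torch
from transformers import BertModel  # assumed import; adjust to the package name you use

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

input_ids = torch.randint(0, 30522, (1, 8))
input_mask = torch.ones_like(input_ids)
token_type_ids = torch.zeros_like(input_ids)

# Pre-2.0 positional order:  forward(input_ids, token_type_ids, attention_mask, ...)
# Post-2.0 positional order: forward(input_ids, attention_mask, token_type_ids, ...)
# so model(input_ids, token_type_ids, input_mask) silently changes meaning after this PR.
# Passing the optional inputs by keyword works identically before and after:
outputs = model(
    input_ids,
    attention_mask=input_mask,
    token_type_ids=token_type_ids,
)
sequence_output, pooled_output = outputs[0], outputs[1]
```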