Fix bool attention mask in transformer encoder #1454
Conversation
Summary: Rather than raise an exception whenever head_dim != 64, we can simply infer the scaling value, emit a warning, and continue. Also add an assertion for the case where embed_dim is not a multiple of num_heads (in which case forward would break).

Reviewed By: parmeet

Differential Revision: D32193989

fbshipit-source-id: 30f68c55f3ec37932252c77c355ae55b8bf34ded
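For illustration, a minimal sketch of what the change described in the summary could look like; the helper name and exact warning text are assumptions, not the actual torchtext code:

    import math
    import warnings

    def _infer_scaling(embed_dim: int, num_heads: int) -> float:
        # forward() would break when splitting embed_dim across heads if it is
        # not evenly divisible, so assert early with a clear message.
        assert embed_dim % num_heads == 0, "embed_dim must be a multiple of num_heads"
        head_dim = embed_dim // num_heads
        # Instead of raising when head_dim != 64, infer the usual softmax scaling
        # of 1 / sqrt(head_dim) and only warn about the non-default size.
        if head_dim != 64:
            warnings.warn(f"head_dim is {head_dim} (not 64); using scaling {head_dim ** -0.5}")
        return 1.0 / math.sqrt(head_dim)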
torchtext/models/roberta/modules.py
Outdated
    if attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=input.dtype)
        new_attn_mask.masked_fill_(attn_mask, -1e8 if input.dtype == torch.float32 else -1e4)
        attn_mask = new_attn_mask
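For context, this snippet turns a boolean mask (True = position to ignore) into an additive float mask. A small standalone illustration, not part of the diff:

    import torch

    # True entries become a large negative value, so after the softmax inside
    # attention those positions receive essentially zero weight.
    attn_mask = torch.tensor([[False, False, True]])
    new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float32)
    new_attn_mask.masked_fill_(attn_mask, -1e8)
    print(new_attn_mask)  # tensor([[ 0.0000e+00,  0.0000e+00, -1.0000e+08]])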
I wonder if it is necessary to add the conversion here, since we have already added it inside TransformerEncoderLayer? Following the same argument, perhaps we could also remove it from TransformerEncoderLayer, since we know the mask is going to be passed to MultiHeadSelfAttention, which would do this conversion. This was also discussed a bit in #1435 (comment). I feel we could avoid the redundancy here. wdyt?
Good question. I think it makes sense to perform the bool-to-float conversion only in MultiHeadSelfAttention, since that is where the mask is actually used. Then we can just keep the checks in TransformerEncoder and TransformerEncoderLayer to ensure the attention mask is either bool or float.
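If we went that route, the dtype check kept in TransformerEncoder / TransformerEncoderLayer might look roughly like this (an illustrative helper, not the actual torchtext code):

    import torch

    def _validate_attn_mask(attn_mask: torch.Tensor) -> None:
        # The encoder layers only validate the mask dtype here; the actual
        # bool-to-float conversion would live in MultiHeadSelfAttention.
        if attn_mask.dtype != torch.bool and not attn_mask.is_floating_point():
            raise TypeError(
                f"attn_mask must be a bool or float tensor, got {attn_mask.dtype}"
            )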
Codecov Report
@@            Coverage Diff             @@
##             main    #1454      +/-   ##
==========================================
+ Coverage   86.35%   86.52%   +0.17%
==========================================
  Files          58       58
  Lines        2220     2219       -1
==========================================
+ Hits         1917     1920       +3
+ Misses        303      299       -4
Continue to review full report at Codecov.
parmeet left a comment
LGTM, thanks for fixing the issue!
The bool attention mask was being cast to float incorrectly. This diff fixes the cast and adds a test case for bool masks.
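A minimal sketch of what such a test could assert (illustrative only; the real test in the diff may differ):

    import torch
    import torch.nn.functional as F

    def test_bool_attn_mask_conversion():
        # After converting the bool mask to an additive float mask, masked
        # positions should get (near) zero attention weight and the remaining
        # weights should still sum to ~1.
        scores = torch.randn(2, 4)  # raw attention scores: 2 queries, 4 keys
        bool_mask = torch.tensor([[False, False, True, False],
                                  [True, False, False, False]])

        float_mask = torch.zeros_like(scores).masked_fill(bool_mask, -1e8)
        probs = F.softmax(scores + float_mask, dim=-1)

        assert torch.all(probs[bool_mask] < 1e-6)
        torch.testing.assert_close(probs.sum(dim=-1), torch.ones(2))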