This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Conversation

@ebsmothers (Contributor)

The bool attention mask was being cast to float incorrectly. This diff fixes the cast and adds an additional test case for bool masks.
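
For reference, such a test could look roughly like the sketch below. This is only an illustration: the helper name, the module's forward signature (the attn_mask keyword), and the mask shape are assumptions, not the actual torchtext test.

import torch

def check_bool_mask_matches_float(attn_module, query, seq_len):
    # A bool mask (True = masked position) should give the same output as the
    # equivalent additive float mask with large negative values at masked positions.
    bool_mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    bool_mask[:, -1] = True  # e.g. disallow attending to the last position
    float_mask = torch.zeros(seq_len, seq_len, dtype=torch.float32)
    float_mask.masked_fill_(bool_mask, -1e8)

    out_bool = attn_module(query, attn_mask=bool_mask)
    out_float = attn_module(query, attn_mask=float_mask)
    torch.testing.assert_close(out_bool, out_float)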

Summary:
Rather than raise an exception whenever head_dim != 64, we can simply infer the scaling value and emit a warning instead.

Also add an assertion that embed_dim is a multiple of num_heads (otherwise forward will break).
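
A minimal sketch of what the summary describes, assuming a constructor shaped roughly like torchtext's MultiHeadSelfAttention; the argument names and warning text here are illustrative, not the exact implementation.

import warnings
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, scaling=None):
        super().__init__()
        # forward() splits embed_dim across heads, so it must divide evenly.
        assert embed_dim % num_heads == 0, "embed_dim must be a multiple of num_heads"
        head_dim = embed_dim // num_heads
        expected_scaling = float(head_dim) ** -0.5
        if scaling is None:
            # Rather than raising when head_dim != 64, infer the scaling value and warn.
            warnings.warn(f"Scaling not provided; inferring scaling={expected_scaling} from head_dim={head_dim}")
            scaling = expected_scaling
        self.num_heads = num_heads
        self.scaling = scaling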

Reviewed By: parmeet

Differential Revision: D32193989

fbshipit-source-id: 30f68c55f3ec37932252c77c355ae55b8bf34ded
# Convert a bool attention mask (True = masked) into an additive float mask
# filled with a large negative value at masked positions:
if attn_mask.dtype == torch.bool:
    new_attn_mask = torch.zeros_like(attn_mask, dtype=input.dtype)
    new_attn_mask.masked_fill_(attn_mask, -1e8 if input.dtype == torch.float32 else -1e4)
    attn_mask = new_attn_mask
Contributor

I wonder if it is necessary to add the conversion here, since we have already added it inside TransformerEncoderLayer. Following the same argument, perhaps we could also remove it from TransformerEncoderLayer, since we know the mask is going to be passed to MultiHeadSelfAttention, which would do this conversion. This was also discussed a bit in #1435 (comment). I feel we could avoid the redundancy here. wdyt?

Contributor (Author)

Good question. I think it makes sense to perform the bool-to-float conversion only in MultiHeadSelfAttention, since that is where the mask is actually used. Then we can just keep the checks in TransformerEncoder and TransformerEncoderLayer to ensure the attention mask is either bool or float.
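
In other words, roughly the split sketched below: the encoder layers only validate the mask dtype, while the module that actually consumes the mask performs the bool-to-float conversion. The helper names are illustrative assumptions, not the actual torchtext code.

import torch

def _check_attn_mask_dtype(attn_mask):
    # Kept in TransformerEncoder / TransformerEncoderLayer: only validate the dtype.
    if attn_mask.dtype != torch.bool and not torch.is_floating_point(attn_mask):
        raise TypeError(f"attn_mask must be a bool or float tensor, got {attn_mask.dtype}")

def _to_additive_float_mask(attn_mask, dtype):
    # Done once inside MultiHeadSelfAttention, where the mask is actually applied.
    if attn_mask.dtype == torch.bool:
        new_attn_mask = torch.zeros_like(attn_mask, dtype=dtype)
        new_attn_mask.masked_fill_(attn_mask, -1e8 if dtype == torch.float32 else -1e4)
        return new_attn_mask
    return attn_mask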

@ebsmothers marked this pull request as ready for review on December 8, 2021 17:23
codecov bot commented Dec 8, 2021

Codecov Report

Merging #1454 (b584f30) into main (9f2fb3f) will increase coverage by 0.17%.
The diff coverage is 75.00%.


@@            Coverage Diff             @@
##             main    #1454      +/-   ##
==========================================
+ Coverage   86.35%   86.52%   +0.17%     
==========================================
  Files          58       58              
  Lines        2220     2219       -1     
==========================================
+ Hits         1917     1920       +3     
+ Misses        303      299       -4     
Impacted Files                         Coverage           Δ
torchtext/models/roberta/modules.py    84.75% <75.00%>    (+2.33%) ⬆️

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 9f2fb3f...b584f30.

@parmeet (Contributor) left a comment

LGTM, thanks for fixing the issue!

@parmeet merged commit a074cb2 into pytorch:main on Dec 8, 2021