
Commit 21d313b

Merge branch 'attn_mask_fix' of https://github.com/ebsmothers/text into attn_mask_fix

2 parents: b061086 + 667321b

2 files changed: +2 −1 lines changed

test/models/test_models.py

1 addition & 1 deletion

@@ -15,8 +15,8 @@ def test_self_attn_mask(self):
         query[0, ...] = 0
         key_padding_mask = torch.zeros((batch_size, source_len))
         float_attn_mask = torch.zeros((source_len, source_len))
-        bool_attn_mask = float_attn_mask.to(dtype=bool)
         float_attn_mask[0][1] = -1e8
+        bool_attn_mask = float_attn_mask.to(dtype=bool)
         with torch.no_grad():
             mha.input_projection.weight.fill_(1. / embed_dim)
             mha.input_projection.bias.fill_(0.)
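
The reorder matters because Tensor.to(dtype=bool) maps zero entries to False and non-zero entries to True: deriving bool_attn_mask before any float entry is set yields an all-False mask, so the test would never exercise a masked position. A minimal standalone sketch of that behavior (shapes chosen for illustration, not taken from the test):

import torch

# Float mask convention: 0.0 means "attend", a large negative value means "masked out".
float_attn_mask = torch.zeros((3, 3))

# Converting before populating the mask gives an all-False bool mask.
premature = float_attn_mask.to(dtype=bool)
assert not premature.any()  # no position is masked yet

float_attn_mask[0][1] = -1e8  # mask out position (0, 1)

# Converting after: non-zero entries become True, matching the float mask.
bool_attn_mask = float_attn_mask.to(dtype=bool)
assert bool_attn_mask[0][1]  # position (0, 1) is now masked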

torchtext/models/roberta/modules.py

1 addition & 0 deletions

@@ -286,6 +286,7 @@ def __init__(
     def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]:
         if attn_mask is not None:
             torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
+
         padding_mask = tokens.eq(self.padding_idx)

         token_embeddings = self.token_embedding(tokens)
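
The assertion guards the two mask conventions the encoder accepts: a floating-point additive mask or a boolean mask. A minimal sketch of the check in isolation, using a hypothetical helper check_attn_mask rather than the actual torchtext module:

import torch

def check_attn_mask(attn_mask: torch.Tensor) -> None:
    # Mirrors the guard in forward(): only float or bool masks pass.
    torch._assert(
        attn_mask.is_floating_point() or attn_mask.dtype == torch.bool,
        f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}",
    )

check_attn_mask(torch.zeros((4, 4)))                     # float mask: passes
check_attn_mask(torch.zeros((4, 4), dtype=torch.bool))   # bool mask: passes
# check_attn_mask(torch.zeros((4, 4), dtype=torch.long)) # would raise: long is neither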
