This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Commit 9314b44

add attention mask to transformer encoder modules (#1435)

1 parent ba20fc5

File tree: 2 files changed (+62, -7 lines)

  test/models/test_models.py
  torchtext/models/roberta/modules.py

test/models/test_models.py

Lines changed: 29 additions & 0 deletions
@@ -5,6 +5,35 @@
 from ..common.assets import get_asset_path


+class TestModules(TorchtextTestCase):
+    def test_self_attn_mask(self):
+        from torchtext.models.roberta.modules import MultiheadSelfAttention
+        embed_dim, batch_size, num_heads, source_len = 4, 1, 2, 2
+        mha = MultiheadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)
+        query = torch.ones((source_len, batch_size, embed_dim))
+        query[0, ...] = 0
+        key_padding_mask = torch.zeros((batch_size, source_len))
+        attn_mask = torch.zeros((source_len, source_len))
+        attn_mask[0][1] = -1e8
+        with torch.no_grad():
+            mha.input_projection.weight.fill_(1. / embed_dim)
+            mha.input_projection.bias.fill_(0.)
+            mha.output_projection.weight.fill_(1. / embed_dim)
+            mha.output_projection.bias.fill_(0.)
+
+        # with attention mask
+        actual = mha(query, key_padding_mask, attn_mask)
+        expected = torch.tensor([[[0.0000, 0.0000, 0.0000, 0.0000]],
+                                 [[0.8938, 0.8938, 0.8938, 0.8938]]])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+        # without attention mask
+        actual = mha(query, key_padding_mask)
+        expected = torch.tensor([[[0.5556, 0.5556, 0.5556, 0.5556]],
+                                 [[0.8938, 0.8938, 0.8938, 0.8938]]])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+
 class TestModels(TorchtextTestCase):
     def test_xlmr_base_output(self):
         asset_name = "xlmr.base.output.pt"
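
A note on the test's mask semantics: with the projection weights set to 1/embed_dim and all biases zeroed, position 0 of query is all zeros, so its raw attention scores are zero; attn_mask[0][1] = -1e8 then pushes the (query 0, key 1) score to a large negative value, and after the softmax essentially all of query 0's attention lands on key 0, whose value vector is also all zeros. That is why the first row of the masked expected output is all zeros. A quick sketch of the softmax effect only:

    import torch

    # Row 0 of the scores after the additive mask from the test is applied:
    # raw scores are zero, the masked (query 0, key 1) entry becomes -1e8.
    row0_scores = torch.tensor([0.0, -1e8])
    print(torch.softmax(row0_scores, dim=0))  # ~tensor([1., 0.])
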

torchtext/models/roberta/modules.py

Lines changed: 33 additions & 7 deletions
@@ -99,7 +99,7 @@ def __init__(
         self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
         self.output_projection = nn.Linear(embed_dim, embed_dim)

-    def forward(self, query, key_padding_mask):
+    def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         target_length, batch_size, embed_dim = query.size()
         mask_batch_size, source_length = key_padding_mask.size()

@@ -124,6 +124,18 @@ def forward(self, query, key_padding_mask):
         )

         attn_weights = torch.bmm(q, k.transpose(1, 2))
+        if attn_mask is not None:
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+            torch._assert(attn_mask.size(0) == target_length, "attn_mask shape didn't match for target length {}".format(target_length))
+            torch._assert(attn_mask.size(1) == source_length, "attn_mask shape didn't match for source length {}".format(source_length))
+            torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
+            if attn_mask.dtype == torch.bool:
+                attn_mask = attn_mask.masked_fill(
+                    attn_mask,
+                    -1e8 if query.dtype == torch.float32 else -1e4
+                )
+            attn_mask = attn_mask.unsqueeze(0)
+            attn_weights += attn_mask

         torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
         torch._assert(
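
For illustration only, the mask handling added above boils down to turning a boolean mask into an additive float bias and broadcasting it over the flattened batch-times-heads dimension of the attention scores. A standalone sketch with made-up sizes (not the module's own code):

    import torch

    # Hypothetical sizes standing in for the values computed inside forward().
    batch_size, num_heads, target_length, source_length = 2, 4, 3, 3
    attn_weights = torch.randn(batch_size * num_heads, target_length, source_length)

    # Boolean mask: True marks pairs a query may not attend to (here, a single pair).
    bool_mask = torch.zeros(target_length, source_length, dtype=torch.bool)
    bool_mask[0][1] = True

    # Additive float equivalent: masked pairs get the same -1e8 used above for float32.
    additive_mask = torch.zeros(target_length, source_length).masked_fill(bool_mask, -1e8)

    # (target_length, source_length) -> (1, target_length, source_length); the leading
    # dim broadcasts across all batch_size * num_heads score matrices.
    attn_weights = attn_weights + additive_mask.unsqueeze(0)
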
@@ -209,18 +221,28 @@ def __init__(
         self.final_layer_norm = nn.LayerNorm(embedding_dim)
         self.normalize_before = normalize_before

-    def forward(self, input, key_padding_mask):
+    def forward(self, input: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+        if attn_mask is not None:
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+            torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
+            if attn_mask.dtype == torch.bool:
+                attn_mask = attn_mask.masked_fill(
+                    attn_mask,
+                    -1e8 if input.dtype == torch.float32 else -1e4
+                )
+
         if not hasattr(self, "normalize_before"):
             self.normalize_before = False
+
         if self.normalize_before:
             x = self.attention_layer_norm(input)
-            attention = self.attention(x, key_padding_mask)
+            attention = self.attention(x, key_padding_mask, attn_mask)
             attention = self.dropout(attention)
             biased_input = input + attention
             x = self.final_layer_norm(biased_input)
             return self.residual_mlp(x) + biased_input
         else:
-            attention = self.attention(input, key_padding_mask)
+            attention = self.attention(input, key_padding_mask, attn_mask)
             attention = self.dropout(attention)
             biased_input = input + attention
             biased_input = self.attention_layer_norm(biased_input)
@@ -267,7 +289,11 @@ def __init__(
         self.normalize_before = normalize_before
         self.return_all_layers = return_all_layers

-    def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]:
+        if attn_mask is not None:
+            torch._assert(attn_mask.dtype == torch.bool, "Expected attn_mask dtype as `torch.bool` but got {}".format(attn_mask.dtype))
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+
         padding_mask = tokens.eq(self.padding_idx)

         token_embeddings = self.token_embedding(tokens)
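
At the encoder level the checks above only accept a 2-D torch.bool mask, which is then handed unchanged to every layer (see the loops below). A hedged sketch of building a causal mask that satisfies both assertions; seq_len, encoder, and tokens are assumptions for illustration, not names from this commit:

    import torch

    seq_len = 8  # assumed sequence length of the token batch

    # 2-D boolean causal mask: True above the diagonal marks "may not attend".
    attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

    # Hypothetical call, assuming `encoder` is an already-constructed TransformerEncoder
    # and `tokens` a LongTensor of token ids:
    # output = encoder(tokens, attn_mask)
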
@@ -289,7 +315,7 @@ def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
             states = [encoded]

             for layer in self.layers:
-                encoded = layer(encoded, padding_mask)
+                encoded = layer(encoded, padding_mask, attn_mask)
                 states.append(encoded)

             if self.normalize_before:
@@ -300,7 +326,7 @@ def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
             return states
         else:
             for layer in self.layers:
-                encoded = layer(encoded, padding_mask)
+                encoded = layer(encoded, padding_mask, attn_mask)

             if self.normalize_before:
                 encoded = self.embedding_layer_norm(encoded)
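
End to end, the lower-level MultiheadSelfAttention path can also be exercised directly with an additive float mask, mirroring the call pattern the new test uses. A minimal, self-contained sketch (sizes chosen arbitrarily, not taken from the commit):

    import torch
    from torchtext.models.roberta.modules import MultiheadSelfAttention

    embed_dim, num_heads, batch_size, seq_len = 8, 2, 3, 5
    mha = MultiheadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)

    # Inputs follow the (seq_len, batch_size, embed_dim) layout used in the test above.
    query = torch.randn(seq_len, batch_size, embed_dim)
    key_padding_mask = torch.zeros(batch_size, seq_len)

    # Causal constraint as an additive float mask: -1e8 above the diagonal.
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    attn_mask = torch.zeros(seq_len, seq_len).masked_fill(causal, -1e8)

    out = mha(query, key_padding_mask, attn_mask)  # same shape as query
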
