29 changes: 29 additions & 0 deletions test/models/test_models.py
@@ -5,6 +5,35 @@
 from ..common.assets import get_asset_path
 
 
+class TestModules(TorchtextTestCase):
+    def test_self_attn_mask(self):
+        from torchtext.models.roberta.modules import MultiheadSelfAttention
+        embed_dim, batch_size, num_heads, source_len = 4, 1, 2, 2
+        mha = MultiheadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)
+        query = torch.ones((source_len, batch_size, embed_dim))
+        query[0, ...] = 0
+        key_padding_mask = torch.zeros((batch_size, source_len))
+        attn_mask = torch.zeros((source_len, source_len))
+        attn_mask[0][1] = -1e8
+        with torch.no_grad():
+            mha.input_projection.weight.fill_(1. / embed_dim)
+            mha.input_projection.bias.fill_(0.)
+            mha.output_projection.weight.fill_(1. / embed_dim)
+            mha.output_projection.bias.fill_(0.)
+
+        # with attention mask
+        actual = mha(query, key_padding_mask, attn_mask)
+        expected = torch.tensor([[[0.0000, 0.0000, 0.0000, 0.0000]],
+                                 [[0.8938, 0.8938, 0.8938, 0.8938]]])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+        # without attention mask
+        actual = mha(query, key_padding_mask)
+        expected = torch.tensor([[[0.5556, 0.5556, 0.5556, 0.5556]],
+                                 [[0.8938, 0.8938, 0.8938, 0.8938]]])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+
 class TestModels(TorchtextTestCase):
     def test_xlmr_base_output(self):
         asset_name = "xlmr.base.output.pt"
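Why these expected values: MultiheadSelfAttention treats attn_mask as an additive term on the raw attention scores before the softmax, so a large negative entry (here -1e8) drives the corresponding attention weight to zero. With attn_mask[0][1] = -1e8, token 0 can only attend to itself, and because its input (and therefore its value projection, with all biases set to 0) is all zeros, the first output row is exactly zero. A minimal sketch of the masking mechanism in plain PyTorch, independent of the torchtext module:

import torch

scores = torch.zeros(2, 2)              # raw attention scores for 2 query positions
attn_mask = torch.zeros(2, 2)
attn_mask[0][1] = -1e8                   # block attention from position 0 to position 1

print(torch.softmax(scores, dim=-1))              # both rows: [0.5000, 0.5000]
print(torch.softmax(scores + attn_mask, dim=-1))  # row 0 collapses to [1.0000, 0.0000]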
40 changes: 33 additions & 7 deletions torchtext/models/roberta/modules.py
@@ -99,7 +99,7 @@ def __init__(
         self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
         self.output_projection = nn.Linear(embed_dim, embed_dim)
 
-    def forward(self, query, key_padding_mask):
+    def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         target_length, batch_size, embed_dim = query.size()
         mask_batch_size, source_length = key_padding_mask.size()
 
@@ -124,6 +124,18 @@ def forward(self, query, key_padding_mask):
         )
 
         attn_weights = torch.bmm(q, k.transpose(1, 2))
+        if attn_mask is not None:
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+            torch._assert(attn_mask.size(0) == target_length, "attn_mask shape didn't match for target length {}".format(target_length))
+            torch._assert(attn_mask.size(1) == source_length, "attn_mask shape didn't match for source length {}".format(source_length))
+            torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
+            if attn_mask.dtype == torch.bool:
+                # build an additive float mask; filling the bool tensor in place would keep dtype bool
+                attn_mask = torch.zeros_like(attn_mask, dtype=query.dtype).masked_fill(
+                    attn_mask,
+                    -1e8 if query.dtype == torch.float32 else -1e4
+                )
+            attn_mask = attn_mask.unsqueeze(0)
+            attn_weights += attn_mask
 
         torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
         torch._assert(
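When a torch.bool mask is passed instead of a float mask, the branch above turns it into an additive float mask: positions marked True receive a large negative value and are effectively dropped from the softmax. A small self-contained sketch of that conversion (the mask pattern here is illustrative):

import torch

bool_mask = torch.tensor([[False, True],
                          [False, False]])
float_mask = torch.zeros_like(bool_mask, dtype=torch.float32).masked_fill(bool_mask, -1e8)
print(float_mask)  # zeros everywhere except -1e8 at the masked (True) position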
Expand Down Expand Up @@ -209,18 +221,28 @@ def __init__(
         self.final_layer_norm = nn.LayerNorm(embedding_dim)
         self.normalize_before = normalize_before
 
-    def forward(self, input, key_padding_mask):
+    def forward(self, input: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+        if attn_mask is not None:
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+            torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
+            if attn_mask.dtype == torch.bool:
+                # build an additive float mask; filling the bool tensor in place would keep dtype bool
+                attn_mask = torch.zeros_like(attn_mask, dtype=input.dtype).masked_fill(
+                    attn_mask,
+                    -1e8 if input.dtype == torch.float32 else -1e4
+                )
+
if not hasattr(self, "normalize_before"):
self.normalize_before = False

if self.normalize_before:
x = self.attention_layer_norm(input)
attention = self.attention(x, key_padding_mask)
attention = self.attention(x, key_padding_mask, attn_mask)
attention = self.dropout(attention)
biased_input = input + attention
x = self.final_layer_norm(biased_input)
return self.residual_mlp(x) + biased_input
else:
attention = self.attention(input, key_padding_mask)
attention = self.attention(input, key_padding_mask, attn_mask)
attention = self.dropout(attention)
biased_input = input + attention
biased_input = self.attention_layer_norm(biased_input)
@@ -267,7 +289,11 @@ def __init__(
         self.normalize_before = normalize_before
         self.return_all_layers = return_all_layers
 
-    def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]:
+        if attn_mask is not None:
+            torch._assert(attn_mask.dtype == torch.bool, "Expected attn_mask dtype as `torch.bool` but got {}".format(attn_mask.dtype))
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+
         padding_mask = tokens.eq(self.padding_idx)
 
         token_embeddings = self.token_embedding(tokens)
@@ -289,7 +315,7 @@ def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor
             states = [encoded]
 
             for layer in self.layers:
-                encoded = layer(encoded, padding_mask)
+                encoded = layer(encoded, padding_mask, attn_mask)
                 states.append(encoded)
 
             if self.normalize_before:
@@ -300,7 +326,7 @@ def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor
             return states
         else:
             for layer in self.layers:
-                encoded = layer(encoded, padding_mask)
+                encoded = layer(encoded, padding_mask, attn_mask)
 
             if self.normalize_before:
                 encoded = self.embedding_layer_norm(encoded)
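A minimal usage sketch of the new argument (not part of the PR; the sequence length, head count, and causal pattern are illustrative assumptions, while the constructor and forward signatures come from the diff above). A float attn_mask of shape (target_length, source_length) is added to the attention scores, so an upper-triangular -1e8 pattern gives causal attention:

import torch
from torchtext.models.roberta.modules import MultiheadSelfAttention

seq_len, batch_size, embed_dim = 5, 2, 16
mha = MultiheadSelfAttention(embed_dim=embed_dim, num_heads=4)

query = torch.randn(seq_len, batch_size, embed_dim)
key_padding_mask = torch.zeros(batch_size, seq_len)   # no padded positions

# additive float mask: -1e8 above the diagonal blocks attention to future positions
causal_mask = torch.triu(torch.full((seq_len, seq_len), -1e8), diagonal=1)

output = mha(query, key_padding_mask, causal_mask)
print(output.shape)  # torch.Size([5, 2, 16]), same layout as the query

TransformerEncoder.forward accepts the same 2-D mask but, per the assertions above, requires torch.bool dtype, with True marking score positions to suppress; the encoder layers convert it to the additive float form internally.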