diff --git a/test/models/test_models.py b/test/models/test_models.py
index c6b5d01acd..b917104c1b 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -5,6 +5,35 @@ from ..common.assets import get_asset_path
 
 
+class TestModules(TorchtextTestCase):
+    def test_self_attn_mask(self):
+        from torchtext.models.roberta.modules import MultiheadSelfAttention
+        embed_dim, batch_size, num_heads, source_len = 4, 1, 2, 2
+        mha = MultiheadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)
+        query = torch.ones((source_len, batch_size, embed_dim))
+        query[0, ...] = 0
+        key_padding_mask = torch.zeros((batch_size, source_len))
+        attn_mask = torch.zeros((source_len, source_len))
+        attn_mask[0][1] = -1e8
+        with torch.no_grad():
+            mha.input_projection.weight.fill_(1. / embed_dim)
+            mha.input_projection.bias.fill_(0.)
+            mha.output_projection.weight.fill_(1. / embed_dim)
+            mha.output_projection.bias.fill_(0.)
+
+        # with attention mask
+        actual = mha(query, key_padding_mask, attn_mask)
+        expected = torch.tensor([[[0.0000, 0.0000, 0.0000, 0.0000]],
+                                 [[0.8938, 0.8938, 0.8938, 0.8938]]])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+        # without attention mask
+        actual = mha(query, key_padding_mask)
+        expected = torch.tensor([[[0.5556, 0.5556, 0.5556, 0.5556]],
+                                 [[0.8938, 0.8938, 0.8938, 0.8938]]])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+
 class TestModels(TorchtextTestCase):
     def test_xlmr_base_output(self):
         asset_name = "xlmr.base.output.pt"
diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index a2e0635a81..901e896270 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -99,7 +99,7 @@ def __init__(
         self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
         self.output_projection = nn.Linear(embed_dim, embed_dim)
 
-    def forward(self, query, key_padding_mask):
+    def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         target_length, batch_size, embed_dim = query.size()
         mask_batch_size, source_length = key_padding_mask.size()
 
@@ -124,6 +124,18 @@ def forward(self, query, key_padding_mask):
         )
 
         attn_weights = torch.bmm(q, k.transpose(1, 2))
+        if attn_mask is not None:
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+            torch._assert(attn_mask.size(0) == target_length, "attn_mask shape didn't match for target length {}".format(target_length))
+            torch._assert(attn_mask.size(1) == source_length, "attn_mask shape didn't match for source length {}".format(source_length))
+            torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
+            if attn_mask.dtype == torch.bool:
+                attn_mask = attn_mask.masked_fill(
+                    attn_mask,
+                    -1e8 if query.dtype == torch.float32 else -1e4
+                )
+            attn_mask = attn_mask.unsqueeze(0)
+            attn_weights += attn_mask
 
         torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
         torch._assert(
@@ -209,18 +221,28 @@ def __init__(
         self.final_layer_norm = nn.LayerNorm(embedding_dim)
         self.normalize_before = normalize_before
 
-    def forward(self, input, key_padding_mask):
+    def forward(self, input: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+        if attn_mask is not None:
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+            torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
+            if attn_mask.dtype == torch.bool:
+                attn_mask = attn_mask.masked_fill(
+                    attn_mask,
+                    -1e8 if input.dtype == torch.float32 else -1e4
+                )
+
         if not hasattr(self, "normalize_before"):
             self.normalize_before = False
+
         if self.normalize_before:
             x = self.attention_layer_norm(input)
-            attention = self.attention(x, key_padding_mask)
+            attention = self.attention(x, key_padding_mask, attn_mask)
             attention = self.dropout(attention)
             biased_input = input + attention
             x = self.final_layer_norm(biased_input)
             return self.residual_mlp(x) + biased_input
         else:
-            attention = self.attention(input, key_padding_mask)
+            attention = self.attention(input, key_padding_mask, attn_mask)
             attention = self.dropout(attention)
             biased_input = input + attention
             biased_input = self.attention_layer_norm(biased_input)
@@ -267,7 +289,11 @@ def __init__(
         self.normalize_before = normalize_before
         self.return_all_layers = return_all_layers
 
-    def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]:
+        if attn_mask is not None:
+            torch._assert(attn_mask.dtype == torch.bool, "Expected attn_mask dtype as `torch.bool` but got {}".format(attn_mask.dtype))
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+
         padding_mask = tokens.eq(self.padding_idx)
 
         token_embeddings = self.token_embedding(tokens)
@@ -289,7 +315,7 @@ def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor
             states = [encoded]
 
             for layer in self.layers:
-                encoded = layer(encoded, padding_mask)
+                encoded = layer(encoded, padding_mask, attn_mask)
                 states.append(encoded)
 
             if self.normalize_before:
@@ -300,7 +326,7 @@ def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor
             return states
         else:
             for layer in self.layers:
-                encoded = layer(encoded, padding_mask)
+                encoded = layer(encoded, padding_mask, attn_mask)
 
             if self.normalize_before:
                 encoded = self.embedding_layer_norm(encoded)
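Usage note (not part of the patch): the sketch below shows how the new attn_mask argument might be exercised against MultiheadSelfAttention, mirroring the shapes used in the test above (query is (seq_len, batch, embed_dim), key_padding_mask is (batch, seq_len), attn_mask is (seq_len, seq_len)). The sizes and the causal mask are illustrative assumptions, not values taken from this diff.

import torch
from torchtext.models.roberta.modules import MultiheadSelfAttention

seq_len, batch_size, embed_dim, num_heads = 8, 2, 16, 4  # illustrative sizes
mha = MultiheadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)

# Shapes follow the test above: sequence-first query, (batch, seq_len) padding mask.
query = torch.randn(seq_len, batch_size, embed_dim)
key_padding_mask = torch.zeros(batch_size, seq_len)

# Boolean causal mask: True marks positions a query may NOT attend to.
# Per the patch, a bool mask is converted internally to an additive float mask
# (-1e8 for float32 inputs); a float mask with 0 / -1e8 entries is also accepted.
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

output = mha(query, key_padding_mask, causal_mask)
print(output.shape)  # torch.Size([8, 2, 16]), same layout as query

At the TransformerEncoder level the patch accepts only a torch.bool mask of shape (seq_len, seq_len) via forward(tokens, attn_mask); each encoder layer then applies the same bool-to-additive conversion shown above.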