From a134ef9406ae3c4cdda3205cda75d092b2654f68 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Sun, 7 Nov 2021 23:57:32 -0500
Subject: [PATCH 1/6] add attention mask to transformer encoder modules

---
 torchtext/models/roberta/modules.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index a2e0635a81..96c5a0816a 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -99,7 +99,7 @@ def __init__(
         self.input_projection = nn.Linear(embed_dim, 3 * embed_dim)
         self.output_projection = nn.Linear(embed_dim, embed_dim)
 
-    def forward(self, query, key_padding_mask):
+    def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         target_length, batch_size, embed_dim = query.size()
         mask_batch_size, source_length = key_padding_mask.size()
 
@@ -124,6 +124,9 @@ def forward(self, query, key_padding_mask):
         )
 
         attn_weights = torch.bmm(q, k.transpose(1, 2))
+        if attn_mask is not None:
+            attn_mask = attn_mask.unsqueeze(0)
+            attn_weights += attn_mask
 
         torch._assert(attn_weights.dim() == 3, "Unexpected attn_weights dim")
         torch._assert(
@@ -209,9 +212,16 @@ def __init__(
         self.final_layer_norm = nn.LayerNorm(embedding_dim)
         self.normalize_before = normalize_before
 
-    def forward(self, input, key_padding_mask):
+    def forward(self, input: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+        if attn_mask is not None:
+            attn_mask = attn_mask.masked_fill(
+                attn_mask.to(torch.bool),
+                -1e8 if input.dtype == torch.float32 else -1e4
+            )
+
         if not hasattr(self, "normalize_before"):
             self.normalize_before = False
+
         if self.normalize_before:
             x = self.attention_layer_norm(input)
             attention = self.attention(x, key_padding_mask)
@@ -267,7 +277,7 @@ def __init__(
         self.normalize_before = normalize_before
         self.return_all_layers = return_all_layers
 
-    def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
+    def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]:
         padding_mask = tokens.eq(self.padding_idx)
 
         token_embeddings = self.token_embedding(tokens)
@@ -289,7 +299,7 @@ def forward(self, tokens: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor
         states = [encoded]
 
         for layer in self.layers:
-            encoded = layer(encoded, padding_mask)
+            encoded = layer(encoded, padding_mask, attn_mask)
             states.append(encoded)
 
         if self.normalize_before:

From 1f772ae9e9380ee7273dedc21e043c292bc5df89 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Mon, 8 Nov 2021 00:01:03 -0500
Subject: [PATCH 2/6] minor fix

---
 torchtext/models/roberta/modules.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index 96c5a0816a..bcffda9ae2 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -310,7 +310,7 @@ def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
             return states
         else:
             for layer in self.layers:
-                encoded = layer(encoded, padding_mask)
+                encoded = layer(encoded, padding_mask, attn_mask)
 
             if self.normalize_before:
                 encoded = self.embedding_layer_norm(encoded)

From 24e8b3ca81a90d577c1f3e82f6dc139b3e8cd94d Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Mon, 8 Nov 2021 00:03:33 -0500
Subject: [PATCH 3/6] fix in forward

---
 torchtext/models/roberta/modules.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index bcffda9ae2..e68d2ea2a3 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -224,13 +224,13 @@ def forward(self, input: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask
 
         if self.normalize_before:
             x = self.attention_layer_norm(input)
-            attention = self.attention(x, key_padding_mask)
+            attention = self.attention(x, key_padding_mask, attn_mask)
             attention = self.dropout(attention)
             biased_input = input + attention
             x = self.final_layer_norm(biased_input)
             return self.residual_mlp(x) + biased_input
         else:
-            attention = self.attention(input, key_padding_mask)
+            attention = self.attention(input, key_padding_mask, attn_mask)
             attention = self.dropout(attention)
             biased_input = input + attention
             biased_input = self.attention_layer_norm(biased_input)

From f4ebc15060b41563eafcaa4f97f784579ed7bf5f Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Mon, 8 Nov 2021 11:09:43 -0500
Subject: [PATCH 4/6] add assertions

---
 torchtext/models/roberta/modules.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index e68d2ea2a3..03f95362d1 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -125,6 +125,7 @@ def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask
 
         attn_weights = torch.bmm(q, k.transpose(1, 2))
         if attn_mask is not None:
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
             attn_mask = attn_mask.unsqueeze(0)
             attn_weights += attn_mask
 
@@ -214,6 +215,8 @@ def __init__(
     def forward(self, input: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         if attn_mask is not None:
+            torch._assert(attn_mask.dtype == torch.bool, "Expected attn_mask dtype as `torch.bool` but got {}".format(attn_mask.dtype))
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
             attn_mask = attn_mask.masked_fill(
                 attn_mask.to(torch.bool),
                 -1e8 if input.dtype == torch.float32 else -1e4
             )
@@ -278,6 +281,10 @@ def __init__(
         self.return_all_layers = return_all_layers
 
     def forward(self, tokens: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> Union[torch.Tensor, List[torch.Tensor]]:
+        if attn_mask is not None:
+            torch._assert(attn_mask.dtype == torch.bool, "Expected attn_mask dtype as `torch.bool` but got {}".format(attn_mask.dtype))
+            torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+
         padding_mask = tokens.eq(self.padding_idx)
 
         token_embeddings = self.token_embedding(tokens)

From 695f9d23a7f1d178b58a30a0be8f4c318df3eda3 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Tue, 9 Nov 2021 00:16:49 -0500
Subject: [PATCH 5/6] add tests

---
 test/models/test_models.py          | 29 +++++++++++++++++++++++++++++
 torchtext/models/roberta/modules.py |  3 ++-
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/test/models/test_models.py b/test/models/test_models.py
index c6b5d01acd..b917104c1b 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -5,6 +5,35 @@
 from ..common.assets import get_asset_path
 
 
+class TestModules(TorchtextTestCase):
+    def test_self_attn_mask(self):
+        from torchtext.models.roberta.modules import MultiheadSelfAttention
+        embed_dim, batch_size, num_heads, source_len = 4, 1, 2, 2
+        mha = MultiheadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)
+        query = torch.ones((source_len, batch_size, embed_dim))
+        query[0, ...] = 0
+        key_padding_mask = torch.zeros((batch_size, source_len))
+        attn_mask = torch.zeros((source_len, source_len))
+        attn_mask[0][1] = -1e8
+        with torch.no_grad():
+            mha.input_projection.weight.fill_(1. / embed_dim)
+            mha.input_projection.bias.fill_(0.)
+            mha.output_projection.weight.fill_(1. / embed_dim)
+            mha.output_projection.bias.fill_(0.)
+
+        # with attention mask
+        actual = mha(query, key_padding_mask, attn_mask)
+        expected = torch.tensor([[[0.0000, 0.0000, 0.0000, 0.0000]],
+                                 [[0.8938, 0.8938, 0.8938, 0.8938]]])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+        # without attention mask
+        actual = mha(query, key_padding_mask)
+        expected = torch.tensor([[[0.5556, 0.5556, 0.5556, 0.5556]],
+                                 [[0.8938, 0.8938, 0.8938, 0.8938]]])
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+
 class TestModels(TorchtextTestCase):
     def test_xlmr_base_output(self):
         asset_name = "xlmr.base.output.pt"

diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index 03f95362d1..38466a0a07 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -126,6 +126,8 @@ def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask
         attn_weights = torch.bmm(q, k.transpose(1, 2))
         if attn_mask is not None:
             torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
+            torch._assert(attn_mask.size(0) == target_length, "attn_mask shape didn't match for target length {}".format(target_length))
+            torch._assert(attn_mask.size(1) == source_length, "attn_mask shape didn't match for source length {}".format(source_length))
             attn_mask = attn_mask.unsqueeze(0)
             attn_weights += attn_mask
 
@@ -215,7 +217,6 @@ def __init__(
 
     def forward(self, input: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         if attn_mask is not None:
-            torch._assert(attn_mask.dtype == torch.bool, "Expected attn_mask dtype as `torch.bool` but got {}".format(attn_mask.dtype))
             torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
             attn_mask = attn_mask.masked_fill(
                 attn_mask.to(torch.bool),

From afe6371dc3363e6e8989078b4161c0e5b91146f4 Mon Sep 17 00:00:00 2001
From: Parmeet Singh Bhatia
Date: Tue, 9 Nov 2021 00:46:06 -0500
Subject: [PATCH 6/6] add valid types for attn_mask

---
 torchtext/models/roberta/modules.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/torchtext/models/roberta/modules.py b/torchtext/models/roberta/modules.py
index 38466a0a07..901e896270 100644
--- a/torchtext/models/roberta/modules.py
+++ b/torchtext/models/roberta/modules.py
@@ -128,6 +128,12 @@ def forward(self, query: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask
             torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
             torch._assert(attn_mask.size(0) == target_length, "attn_mask shape didn't match for target length {}".format(target_length))
             torch._assert(attn_mask.size(1) == source_length, "attn_mask shape didn't match for source length {}".format(source_length))
+            torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
+            if attn_mask.dtype == torch.bool:
+                attn_mask = attn_mask.masked_fill(
+                    attn_mask,
+                    -1e8 if query.dtype == torch.float32 else -1e4
+                )
             attn_mask = attn_mask.unsqueeze(0)
             attn_weights += attn_mask
 
@@ -218,10 +224,12 @@ def __init__(
     def forward(self, input: torch.Tensor, key_padding_mask: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         if attn_mask is not None:
             torch._assert(attn_mask.dim() == 2, "Expected attn_mask of dim 2 but got {}".format(attn_mask.dim()))
-            attn_mask = attn_mask.masked_fill(
-                attn_mask.to(torch.bool),
-                -1e8 if input.dtype == torch.float32 else -1e4
-            )
+            torch._assert(attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, f"Only float or bool types are supported for attn_mask not {attn_mask.dtype}")
+            if attn_mask.dtype == torch.bool:
+                attn_mask = attn_mask.masked_fill(
+                    attn_mask,
+                    -1e8 if input.dtype == torch.float32 else -1e4
+                )
 
         if not hasattr(self, "normalize_before"):
             self.normalize_before = False
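
Usage note (not part of the patch series): the sketch below shows how the attn_mask argument introduced above could be exercised directly against MultiheadSelfAttention, mirroring the unit test added in PATCH 5/6. It uses a 2-D additive float mask of shape (target_length, source_length): 0.0 where attention is allowed and a large negative value where it is blocked, which the patched forward broadcasts over the (batch * heads) dimension and adds to the raw attention weights. The tensor sizes and variable names here are illustrative assumptions, not values taken from the patches.

    # Illustrative sketch; shapes follow the test in PATCH 5/6:
    # query is (T, B, E), key_padding_mask is (B, S), attn_mask is (T, S).
    import torch
    from torchtext.models.roberta.modules import MultiheadSelfAttention

    embed_dim, num_heads, seq_len, batch_size = 8, 2, 4, 3   # assumed toy sizes
    mha = MultiheadSelfAttention(embed_dim=embed_dim, num_heads=num_heads)

    query = torch.randn(seq_len, batch_size, embed_dim)      # (T, B, E)
    key_padding_mask = torch.zeros(batch_size, seq_len)      # (B, S), no padded positions

    # Additive float mask: -1e8 above the diagonal makes attention causal,
    # the same way the test blocks a single position with attn_mask[0][1] = -1e8.
    attn_mask = torch.zeros(seq_len, seq_len)
    attn_mask[torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()] = -1e8

    output = mha(query, key_padding_mask, attn_mask)         # (T, B, E)

After PATCH 6/6 a torch.bool mask (True marking positions that must not be attended to) is also accepted by the assertions and converted to additive values inside forward; a float mask as above passes through unchanged.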