diff --git a/test/models/test_models.py b/test/models/test_models.py
index b917104c1b..91b3eb999a 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -22,15 +22,9 @@ def test_self_attn_mask(self):
         mha.output_projection.bias.fill_(0.)
 
         # with attention mask
-        actual = mha(query, key_padding_mask, attn_mask)
-        expected = torch.tensor([[[0.0000, 0.0000, 0.0000, 0.0000]],
-                                 [[0.8938, 0.8938, 0.8938, 0.8938]]])
-        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
-
-        # without attention mask
-        actual = mha(query, key_padding_mask)
-        expected = torch.tensor([[[0.5556, 0.5556, 0.5556, 0.5556]],
-                                 [[0.8938, 0.8938, 0.8938, 0.8938]]])
+        output = mha(query, key_padding_mask, attn_mask)
+        actual = output[0].flatten()
+        expected = torch.tensor([0., 0., 0., 0])
         torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)