From 15f0fdf7cfec41ef981b1b7eca2d2298cff03192 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Tue, 12 Nov 2024 13:39:42 -0800 Subject: [PATCH 1/3] Fix flaky ET attention test --- extension/llm/modules/test/test_attention.py | 15 ++++++++++++--- pytest.ini | 1 - 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py index 9ae136a2137..d9cabf0cb1f 100644 --- a/extension/llm/modules/test/test_attention.py +++ b/extension/llm/modules/test/test_attention.py @@ -94,7 +94,10 @@ def test_attention_eager(self): et_res = self.et_mha(self.x, self.x) # Self attention. tt_res = self.tt_mha(self.x, self.x) # Self attention. - self.assertTrue(torch.allclose(et_res, tt_res)) + self.assertTrue( + torch.allclose(et_res, tt_res), + msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}", + ) # test with kv cache self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20) @@ -136,7 +139,10 @@ def test_attention_export(self): ) et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos) tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos) - self.assertTrue(torch.allclose(et_res, tt_res)) + self.assertTrue( + torch.allclose(et_res, tt_res), + msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}", + ) # TODO: KV cache. @@ -162,6 +168,9 @@ def test_attention_executorch(self): et_res = method.execute((self.x, self.x, self.input_pos)) tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos) - self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06)) + self.assertTrue( + torch.allclose(et_res[0], tt_res, atol=1e-05), + msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res[0]}", + ) # TODO: KV cache. diff --git a/pytest.ini b/pytest.ini index 03c015c3979..a5041504aef 100644 --- a/pytest.ini +++ b/pytest.ini @@ -39,7 +39,6 @@ addopts = backends/xnnpack/test # extension/ extension/llm/modules/test - --ignore=extension/llm/modules/test/test_mha.py extension/pybindings/test # Runtime runtime From d066a259d7f6ac62dc1f66b0149c3ab7ec6c1409 Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Tue, 12 Nov 2024 14:05:22 -0800 Subject: [PATCH 2/3] Use assert_close --- extension/llm/modules/test/test_attention.py | 24 ++++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py index d9cabf0cb1f..66f745a0089 100644 --- a/extension/llm/modules/test/test_attention.py +++ b/extension/llm/modules/test/test_attention.py @@ -13,6 +13,7 @@ MultiHeadAttention as ETMultiHeadAttention, ) from executorch.runtime import Runtime +from torch.testing import assert_close from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention @@ -94,8 +95,9 @@ def test_attention_eager(self): et_res = self.et_mha(self.x, self.x) # Self attention. tt_res = self.tt_mha(self.x, self.x) # Self attention. - self.assertTrue( - torch.allclose(et_res, tt_res), + assert_close( + et_res, + tt_res, msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}", ) @@ -127,7 +129,12 @@ def test_attention_eager(self): tt_res = self.tt_mha( self.x, self.x, input_pos=next_input_pos ) # Self attention with input pos. - self.assertTrue(torch.allclose(et_res, tt_res)) + + assert_close( + et_res, + tt_res, + msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}", + ) def test_attention_export(self): # Self attention. @@ -139,8 +146,10 @@ def test_attention_export(self): ) et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos) tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos) - self.assertTrue( - torch.allclose(et_res, tt_res), + + assert_close( + et_res, + tt_res, msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}", ) @@ -168,8 +177,9 @@ def test_attention_executorch(self): et_res = method.execute((self.x, self.x, self.input_pos)) tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos) - self.assertTrue( - torch.allclose(et_res[0], tt_res, atol=1e-05), + assert_close( + et_res[0], + tt_res, msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res[0]}", ) From a1baf0d3a47410a3365d6742864e287b62bb5768 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Tue, 12 Nov 2024 15:24:01 -0800 Subject: [PATCH 3/3] Remove msg from assert_close Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- extension/llm/modules/test/test_attention.py | 24 ++++---------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/extension/llm/modules/test/test_attention.py b/extension/llm/modules/test/test_attention.py index 66f745a0089..565e8c67d75 100644 --- a/extension/llm/modules/test/test_attention.py +++ b/extension/llm/modules/test/test_attention.py @@ -95,11 +95,7 @@ def test_attention_eager(self): et_res = self.et_mha(self.x, self.x) # Self attention. tt_res = self.tt_mha(self.x, self.x) # Self attention. - assert_close( - et_res, - tt_res, - msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}", - ) + assert_close(et_res, tt_res) # test with kv cache self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20) @@ -130,11 +126,7 @@ def test_attention_eager(self): self.x, self.x, input_pos=next_input_pos ) # Self attention with input pos. - assert_close( - et_res, - tt_res, - msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}", - ) + assert_close(et_res, tt_res) def test_attention_export(self): # Self attention. @@ -147,11 +139,7 @@ def test_attention_export(self): et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos) tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos) - assert_close( - et_res, - tt_res, - msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}", - ) + assert_close(et_res, tt_res) # TODO: KV cache. @@ -177,10 +165,6 @@ def test_attention_executorch(self): et_res = method.execute((self.x, self.x, self.input_pos)) tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos) - assert_close( - et_res[0], - tt_res, - msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res[0]}", - ) + assert_close(et_res[0], tt_res) # TODO: KV cache.