1313 MultiHeadAttention as ETMultiHeadAttention ,
1414)
1515from executorch .runtime import Runtime
16+ from torch .testing import assert_close
1617from torchtune .models .llama3_1 ._position_embeddings import Llama3ScaledRoPE
1718from torchtune .modules .attention import MultiHeadAttention as TTMultiHeadAttention
1819
@@ -92,8 +93,9 @@ def test_attention_eager(self):
9293 et_res = self .et_mha (self .x , self .x ) # Self attention.
9394 tt_res = self .tt_mha (self .x , self .x ) # Self attention.
9495
95- self .assertTrue (
96- torch .allclose (et_res , tt_res ),
96+ assert_close (
97+ et_res ,
98+ tt_res ,
9799 msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
98100 )
99101
@@ -104,7 +106,11 @@ def test_attention_eager(self):
104106 # et_res = self.et_mha(self.x, self.x) # Self attention.
105107 # tt_res = self.tt_mha(self.x, self.x) # Self attention.
106108
107- # self.assertTrue(torch.allclose(et_res, tt_res))
109+ # assert_close(
110+ # et_res,
111+ # tt_res,
112+ # msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}"
113+ # )
108114
109115 def test_attention_export (self ):
110116 # Self attention.
@@ -116,8 +122,10 @@ def test_attention_export(self):
116122 )
117123 et_res = et_mha_ep .module ()(self .x , self .x )
118124 tt_res = self .tt_mha (self .x , self .x )
119- self .assertTrue (
120- torch .allclose (et_res , tt_res ),
125+
126+ assert_close (
127+ et_res ,
128+ tt_res ,
121129 msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
122130 )
123131
@@ -145,8 +153,9 @@ def test_attention_executorch(self):
145153 et_res = method .execute ((self .x , self .x ))
146154 tt_res = self .tt_mha (self .x , self .x )
147155
148- self .assertTrue (
149- torch .allclose (et_res [0 ], tt_res , atol = 1e-05 ),
156+ assert_close (
157+ et_res [0 ],
158+ tt_res ,
150159 msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res [0 ]} " ,
151160 )
152161
0 commit comments