1313 MultiHeadAttention as ETMultiHeadAttention ,
1414)
1515from executorch .runtime import Runtime
16+ from torch .testing import assert_close
1617from torchtune .models .llama3_1 ._position_embeddings import Llama3ScaledRoPE
1718from torchtune .modules .attention import MultiHeadAttention as TTMultiHeadAttention
1819
@@ -94,8 +95,9 @@ def test_attention_eager(self):
9495 et_res = self .et_mha (self .x , self .x ) # Self attention.
9596 tt_res = self .tt_mha (self .x , self .x ) # Self attention.
9697
97- self .assertTrue (
98- torch .allclose (et_res , tt_res ),
98+ assert_close (
99+ et_res ,
100+ tt_res ,
99101 msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
100102 )
101103
@@ -127,7 +129,12 @@ def test_attention_eager(self):
127129 tt_res = self .tt_mha (
128130 self .x , self .x , input_pos = next_input_pos
129131 ) # Self attention with input pos.
130- self .assertTrue (torch .allclose (et_res , tt_res ))
132+
133+ assert_close (
134+ et_res ,
135+ tt_res ,
136+ msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
137+ )
131138
132139 def test_attention_export (self ):
133140 # Self attention.
@@ -139,8 +146,10 @@ def test_attention_export(self):
139146 )
140147 et_res = et_mha_ep .module ()(self .x , self .x , input_pos = self .input_pos )
141148 tt_res = self .tt_mha (self .x , self .x , input_pos = self .input_pos )
142- self .assertTrue (
143- torch .allclose (et_res , tt_res ),
149+
150+ assert_close (
151+ et_res ,
152+ tt_res ,
144153 msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
145154 )
146155
@@ -168,8 +177,9 @@ def test_attention_executorch(self):
168177 et_res = method .execute ((self .x , self .x , self .input_pos ))
169178 tt_res = self .tt_mha (self .x , self .x , input_pos = self .input_pos )
170179
171- self .assertTrue (
172- torch .allclose (et_res [0 ], tt_res , atol = 1e-05 ),
180+ assert_close (
181+ et_res [0 ],
182+ tt_res ,
173183 msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res [0 ]} " ,
174184 )
175185
0 commit comments