@@ -33,6 +33,31 @@ def GetResource(self, filepath):
3333 file_text = f .read ()
3434 return file_text
3535
36+ def SingleFftAutoScaleTest (self , filename ):
37+ lines = self .GetResource (filename ).splitlines ()
38+ func = tf .function (fft_ops .fft_auto_scale )
39+ input_size = len (lines [0 ].split ())
40+ concrete_function = func .get_concrete_function (
41+ tf .TensorSpec (input_size , dtype = tf .int16 ))
42+ interpreter = util .get_tflm_interpreter (concrete_function , func )
43+ i = 0
44+ while i < len (lines ):
45+ in_frame = np .array ([int (j ) for j in lines [i ].split ()], dtype = np .int16 )
46+ out_frame_exp = [int (j ) for j in lines [i + 1 ].split ()]
47+ scale_exp = [int (j ) for j in lines [i + 2 ].split ()]
48+ # TFLM
49+ interpreter .set_input (in_frame , 0 )
50+ interpreter .invoke ()
51+ out_frame = interpreter .get_output (0 )
52+ scale = interpreter .get_output (1 )
53+ self .assertAllEqual (out_frame_exp , out_frame )
54+ self .assertEqual (scale_exp , scale )
55+ # TF
56+ out_frame , scale = self .evaluate (fft_ops .fft_auto_scale (in_frame ))
57+ self .assertAllEqual (out_frame_exp , out_frame )
58+ self .assertEqual (scale_exp , scale )
59+ i += 3
60+
3661 def SingleRfftTest (self , filename ):
3762 lines = self .GetResource (filename ).splitlines ()
3863 args = lines [0 ].split ()
@@ -43,8 +68,6 @@ def SingleRfftTest(self, filename):
4368 tf .TensorSpec (input_size , dtype = tf .int16 ), fft_length )
4469 # TODO(b/286252893): make test more robust (vs scipy)
4570 interpreter = util .get_tflm_interpreter (concrete_function , func )
46- input_details = interpreter .get_input_details ()
47- output_details = interpreter .get_output_details ()
4871 # Skip line 0, which contains the configuration params.
4972 # Read lines in pairs <input, expected>
5073 i = 1
@@ -53,9 +76,9 @@ def SingleRfftTest(self, filename):
5376 out_frame_exp = [int (j ) for j in lines [i + 1 ].split ()]
5477 # Compare TFLM inference against the expected golden values
5578 # TODO(b/286252893): validate usage of testing vs interpreter here
56- interpreter .set_tensor ( input_details [ 0 ][ 'index' ], in_frame )
79+ interpreter .set_input ( in_frame , 0 )
5780 interpreter .invoke ()
58- out_frame = interpreter .get_tensor ( output_details [ 0 ][ 'index' ] )
81+ out_frame = interpreter .get_output ( 0 )
5982 self .assertAllEqual (out_frame_exp , out_frame )
6083 # TF
6184 out_frame = self .evaluate (fft_ops .rfft (in_frame , fft_length ))
@@ -83,11 +106,9 @@ def MultiDimRfftTest(self, filename):
83106 concrete_function = func .get_concrete_function (
84107 tf .TensorSpec (np .shape (in_frames ), dtype = tf .int16 ), fft_length )
85108 interpreter = util .get_tflm_interpreter (concrete_function , func )
86- input_details = interpreter .get_input_details ()
87- output_details = interpreter .get_output_details ()
88- interpreter .set_tensor (input_details [0 ]['index' ], in_frames )
109+ interpreter .set_input (in_frames , 0 )
89110 interpreter .invoke ()
90- out_frame = interpreter .get_tensor ( output_details [ 0 ][ 'index' ] )
111+ out_frame = interpreter .get_output ( 0 )
91112 self .assertAllEqual (out_frames_exp , out_frame )
92113 # TF
93114 out_frames = self .evaluate (fft_ops .rfft (in_frames , fft_length ))
@@ -204,6 +225,12 @@ def testRfftSineTest(self):
204225 delta = 1 )
205226 fft_length = 2 * fft_length
206227
228+ def testRfft (self ):
229+ self .SingleRfftTest ('testdata/rfft_test1.txt' )
230+
231+ def testRfftLargeOuterDimension (self ):
232+ self .MultiDimRfftTest ('testdata/rfft_test1.txt' )
233+
207234 def testFftTooLarge (self ):
208235 for dtype in [np .int16 , np .int32 , np .float32 ]:
209236 fft_input = np .zeros (round (fft_ops ._MAX_FFT_LENGTH * 2 ), dtype = dtype )
@@ -224,6 +251,9 @@ def testFftLengthNoEven(self):
224251 with self .assertRaises ((tf .errors .InvalidArgumentError , ValueError )):
225252 self .evaluate (fft_ops .rfft (fft_input , 127 ))
226253
254+ def testAutoScale (self ):
255+ self .SingleFftAutoScaleTest ('testdata/fft_auto_scale_test1.txt' )
256+
227257 def testPow2FftLengthTest (self ):
228258 fft_length , fft_bits = fft_ops .get_pow2_fft_length (131 )
229259 self .assertEqual (fft_length , 256 )
0 commit comments