Skip to content

Commit 70aed11

Browse files
authored
Adds FFT Auto Scale Op (#2134)
This PR adds additional FFT op functionality in the Signal library, namely adding the FFT Auto Scale operation. Testing added in the original `fft_test.cc` and `fft_ops_test.py`. BUG=[287346710](http://b/287346710)
1 parent 52007f6 commit 70aed11

File tree

22 files changed

+1107
-8
lines changed

22 files changed

+1107
-8
lines changed

python/tflite_micro/python_ops_resolver.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ PythonOpsResolver::PythonOpsResolver() {
5151
AddEthosU();
5252
AddExp();
5353
AddExpandDims();
54+
AddFftAutoScale();
5455
AddFill();
5556
AddFloor();
5657
AddFloorDiv();

python/tflite_micro/signal/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ py_tflm_signal_library(
4646
py_test(
4747
name = "fft_ops_test",
4848
srcs = ["ops/fft_ops_test.py"],
49+
data = [
50+
"//python/tflite_micro/signal/ops/testdata:fft_auto_scale_test1.txt",
51+
"//python/tflite_micro/signal/ops/testdata:rfft_test1.txt",
52+
],
4953
python_version = "PY3",
5054
srcs_version = "PY3",
5155
deps = [

python/tflite_micro/signal/ops/fft_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,22 @@ def _fft(input_tensor, fft_length, name=default_name):
6565
return _fft
6666

6767

68+
def _fft_auto_scale_wrapper(fft_auto_scale_fn, default_name):
69+
"""Wrapper around gen_fft_ops.fft_auto_scale*."""
70+
71+
def _fft_auto_scale(input_tensor, name=default_name):
72+
with tf.name_scope(name) as name:
73+
input_tensor = tf.convert_to_tensor(input_tensor, dtype=tf.int16)
74+
dim_list = input_tensor.shape.as_list()
75+
if len(dim_list) != 1:
76+
raise ValueError("Input tensor must have a rank of 1")
77+
return fft_auto_scale_fn(input_tensor, name=name)
78+
79+
return _fft_auto_scale
80+
81+
6882
rfft = _fft_wrapper(gen_fft_ops.signal_rfft, "signal_rfft")
83+
fft_auto_scale = _fft_auto_scale_wrapper(gen_fft_ops.signal_fft_auto_scale,
84+
"signal_fft_auto_scale")
6985
tf.no_gradient("signal_rfft")
86+
tf.no_gradient("signal_fft_auto_scale")

python/tflite_micro/signal/ops/fft_ops_test.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,31 @@ def GetResource(self, filepath):
3333
file_text = f.read()
3434
return file_text
3535

36+
def SingleFftAutoScaleTest(self, filename):
37+
lines = self.GetResource(filename).splitlines()
38+
func = tf.function(fft_ops.fft_auto_scale)
39+
input_size = len(lines[0].split())
40+
concrete_function = func.get_concrete_function(
41+
tf.TensorSpec(input_size, dtype=tf.int16))
42+
interpreter = util.get_tflm_interpreter(concrete_function, func)
43+
i = 0
44+
while i < len(lines):
45+
in_frame = np.array([int(j) for j in lines[i].split()], dtype=np.int16)
46+
out_frame_exp = [int(j) for j in lines[i + 1].split()]
47+
scale_exp = [int(j) for j in lines[i + 2].split()]
48+
# TFLM
49+
interpreter.set_input(in_frame, 0)
50+
interpreter.invoke()
51+
out_frame = interpreter.get_output(0)
52+
scale = interpreter.get_output(1)
53+
self.assertAllEqual(out_frame_exp, out_frame)
54+
self.assertEqual(scale_exp, scale)
55+
# TF
56+
out_frame, scale = self.evaluate(fft_ops.fft_auto_scale(in_frame))
57+
self.assertAllEqual(out_frame_exp, out_frame)
58+
self.assertEqual(scale_exp, scale)
59+
i += 3
60+
3661
def SingleRfftTest(self, filename):
3762
lines = self.GetResource(filename).splitlines()
3863
args = lines[0].split()
@@ -43,8 +68,6 @@ def SingleRfftTest(self, filename):
4368
tf.TensorSpec(input_size, dtype=tf.int16), fft_length)
4469
# TODO(b/286252893): make test more robust (vs scipy)
4570
interpreter = util.get_tflm_interpreter(concrete_function, func)
46-
input_details = interpreter.get_input_details()
47-
output_details = interpreter.get_output_details()
4871
# Skip line 0, which contains the configuration params.
4972
# Read lines in pairs <input, expected>
5073
i = 1
@@ -53,9 +76,9 @@ def SingleRfftTest(self, filename):
5376
out_frame_exp = [int(j) for j in lines[i + 1].split()]
5477
# Compare TFLM inference against the expected golden values
5578
# TODO(b/286252893): validate usage of testing vs interpreter here
56-
interpreter.set_tensor(input_details[0]['index'], in_frame)
79+
interpreter.set_input(in_frame, 0)
5780
interpreter.invoke()
58-
out_frame = interpreter.get_tensor(output_details[0]['index'])
81+
out_frame = interpreter.get_output(0)
5982
self.assertAllEqual(out_frame_exp, out_frame)
6083
# TF
6184
out_frame = self.evaluate(fft_ops.rfft(in_frame, fft_length))
@@ -83,11 +106,9 @@ def MultiDimRfftTest(self, filename):
83106
concrete_function = func.get_concrete_function(
84107
tf.TensorSpec(np.shape(in_frames), dtype=tf.int16), fft_length)
85108
interpreter = util.get_tflm_interpreter(concrete_function, func)
86-
input_details = interpreter.get_input_details()
87-
output_details = interpreter.get_output_details()
88-
interpreter.set_tensor(input_details[0]['index'], in_frames)
109+
interpreter.set_input(in_frames, 0)
89110
interpreter.invoke()
90-
out_frame = interpreter.get_tensor(output_details[0]['index'])
111+
out_frame = interpreter.get_output(0)
91112
self.assertAllEqual(out_frames_exp, out_frame)
92113
# TF
93114
out_frames = self.evaluate(fft_ops.rfft(in_frames, fft_length))
@@ -204,6 +225,12 @@ def testRfftSineTest(self):
204225
delta=1)
205226
fft_length = 2 * fft_length
206227

228+
def testRfft(self):
229+
self.SingleRfftTest('testdata/rfft_test1.txt')
230+
231+
def testRfftLargeOuterDimension(self):
232+
self.MultiDimRfftTest('testdata/rfft_test1.txt')
233+
207234
def testFftTooLarge(self):
208235
for dtype in [np.int16, np.int32, np.float32]:
209236
fft_input = np.zeros(round(fft_ops._MAX_FFT_LENGTH * 2), dtype=dtype)
@@ -224,6 +251,9 @@ def testFftLengthNoEven(self):
224251
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
225252
self.evaluate(fft_ops.rfft(fft_input, 127))
226253

254+
def testAutoScale(self):
255+
self.SingleFftAutoScaleTest('testdata/fft_auto_scale_test1.txt')
256+
227257
def testPow2FftLengthTest(self):
228258
fft_length, fft_bits = fft_ops.get_pow2_fft_length(131)
229259
self.assertEqual(fft_length, 256)

python/tflite_micro/signal/ops/testdata/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package(
77
)
88

99
exports_files([
10+
"fft_auto_scale_test1.txt",
1011
"rfft_test1.txt",
1112
"window_test1.txt",
1213
])

0 commit comments

Comments
 (0)