Skip to content

Commit 55037d2

Browse files
authored
Adds IRFFT Op to Signal Library (#2137)
Inverse-RFFT as part of Signal library ops. Testing via current FFT Op tests. BUG=[287346710](http://b/287346710)
1 parent ed11500 commit 55037d2

File tree

20 files changed

+1310
-2
lines changed

20 files changed

+1310
-2
lines changed

python/tflite_micro/python_ops_resolver.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ PythonOpsResolver::PythonOpsResolver() {
6464
AddGreaterEqual();
6565
AddHardSwish();
6666
AddIf();
67+
AddIrfft();
6768
AddL2Normalization();
6869
AddL2Pool2D();
6970
AddLeakyRelu();

python/tflite_micro/signal/ops/fft_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def _fft_auto_scale(input_tensor, name=default_name):
8080

8181

8282
rfft = _fft_wrapper(gen_fft_ops.signal_rfft, "signal_rfft")
83+
irfft = _fft_wrapper(gen_fft_ops.signal_irfft, "signal_irfft")
8384
fft_auto_scale = _fft_auto_scale_wrapper(gen_fft_ops.signal_fft_auto_scale,
8485
"signal_fft_auto_scale")
8586
tf.no_gradient("signal_rfft")
87+
tf.no_gradient("signal_irfft")
8688
tf.no_gradient("signal_fft_auto_scale")

python/tflite_micro/signal/ops/fft_ops_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,61 @@ def testFftLengthNoEven(self):
251251
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
252252
self.evaluate(fft_ops.rfft(fft_input, 127))
253253

254+
def testIrfftTest(self):
255+
for dtype in [np.int16, np.int32, np.float32]:
256+
fft_length = fft_ops._MIN_FFT_LENGTH
257+
while fft_length <= fft_ops._MAX_FFT_LENGTH:
258+
if dtype == np.float32:
259+
# Random input in the range [-1, 1)
260+
fft_input = np.random.random(fft_length).astype(dtype) * 2 - 1
261+
else:
262+
fft_input = np.random.randint(
263+
np.iinfo(np.int16).min,
264+
np.iinfo(np.int16).max + 1, fft_length).astype(dtype)
265+
fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
266+
self.assertEqual(fft_output.shape[0], (fft_length / 2 + 1) * 2)
267+
ifft_output = self.evaluate(fft_ops.irfft(fft_output, fft_length))
268+
self.assertEqual(ifft_output.shape[0], fft_length)
269+
# Output of integer RFFT and IRFFT is scaled by 1/fft_length
270+
if dtype == np.int16:
271+
self.assertArrayNear(fft_input,
272+
ifft_output.astype(np.int32) * fft_length, 6500)
273+
elif dtype == np.int32:
274+
self.assertArrayNear(fft_input,
275+
ifft_output.astype(np.int32) * fft_length, 7875)
276+
else:
277+
self.assertArrayNear(fft_input, ifft_output, 5e-7)
278+
fft_length = 2 * fft_length
279+
280+
def testIrfftLargeOuterDimension(self):
281+
for dtype in [np.int16, np.int32, np.float32]:
282+
fft_length = fft_ops._MIN_FFT_LENGTH
283+
while fft_length <= fft_ops._MAX_FFT_LENGTH:
284+
if dtype == np.float32:
285+
# Random input in the range [-1, 1)
286+
fft_input = np.random.random([2, 5, fft_length
287+
]).astype(dtype) * 2 - 1
288+
else:
289+
fft_input = np.random.randint(
290+
np.iinfo(np.int16).min,
291+
np.iinfo(np.int16).max + 1, [2, 5, fft_length]).astype(dtype)
292+
fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
293+
self.assertEqual(fft_output.shape[-1], (fft_length / 2 + 1) * 2)
294+
ifft_output = self.evaluate(fft_ops.irfft(fft_output, fft_length))
295+
self.assertEqual(ifft_output.shape[-1], fft_length)
296+
# Output of integer RFFT and IRFFT is scaled by 1/fft_length
297+
if dtype == np.int16:
298+
self.assertAllClose(fft_input,
299+
ifft_output.astype(np.int32) * fft_length,
300+
atol=7875)
301+
elif dtype == np.int32:
302+
self.assertAllClose(fft_input,
303+
ifft_output.astype(np.int32) * fft_length,
304+
atol=7875)
305+
else:
306+
self.assertAllClose(fft_input, ifft_output, rtol=5e-7, atol=5e-7)
307+
fft_length = 2 * fft_length
308+
254309
def testAutoScale(self):
255310
self.SingleFftAutoScaleTest('testdata/fft_auto_scale_test1.txt')
256311

signal/micro/kernels/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ cc_library(
1616
"filter_bank_spectral_subtraction.cc",
1717
"filter_bank_square_root.cc",
1818
"framer.cc",
19+
"irfft.cc",
1920
"overlap_add.cc",
2021
"rfft.cc",
2122
"stacker.cc",
2223
"window.cc",
2324
],
2425
hdrs = [
26+
"irfft.h",
2527
"rfft.h",
2628
],
2729
copts = micro_copts(),
@@ -36,6 +38,7 @@ cc_library(
3638
"//signal/src:filter_bank_log",
3739
"//signal/src:filter_bank_spectral_subtraction",
3840
"//signal/src:filter_bank_square_root",
41+
"//signal/src:irfft",
3942
"//signal/src:overlap_add",
4043
"//signal/src:rfft",
4144
"//signal/src:window",

signal/micro/kernels/fft_test.cc

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,164 @@ TF_LITE_MICRO_TEST(RfftTestSize512Int32) {
303303
g_gen_data_size_fft_length_512_int32, output, 0));
304304
}
305305

306+
TF_LITE_MICRO_TEST(IrfftTestLength64Float) {
307+
constexpr int kOutputLen = 64;
308+
int input_shape[] = {1, 66};
309+
const float input[] = {256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
310+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
311+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
312+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
313+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
314+
int output_shape[] = {1, kOutputLen};
315+
const float golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
316+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
317+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
318+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
319+
float output[kOutputLen];
320+
const TFLMRegistration* registration =
321+
tflite::tflm_signal::Register_IRFFT_FLOAT();
322+
TF_LITE_MICRO_EXPECT_EQ(
323+
kTfLiteOk, tflite::testing::TestFFT<float>(
324+
input_shape, input, output_shape, golden, *registration,
325+
g_gen_data_fft_length_64_float,
326+
g_gen_data_size_fft_length_64_int16, output, 1e-7));
327+
}
328+
329+
TF_LITE_MICRO_TEST(IrfftTestLength64Int16) {
330+
constexpr int kOutputLen = 64;
331+
int input_shape[] = {1, 66};
332+
const int16_t input[] = {
333+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
334+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
335+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
336+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
337+
int output_shape[] = {1, kOutputLen};
338+
const int16_t golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
339+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
340+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
341+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
342+
int16_t output[kOutputLen];
343+
const TFLMRegistration* registration =
344+
tflite::tflm_signal::Register_IRFFT_INT16();
345+
TF_LITE_MICRO_EXPECT_EQ(
346+
kTfLiteOk, tflite::testing::TestFFT<int16_t>(
347+
input_shape, input, output_shape, golden, *registration,
348+
g_gen_data_fft_length_64_int16,
349+
g_gen_data_size_fft_length_64_int16, output, 0));
350+
}
351+
352+
TF_LITE_MICRO_TEST(IrfftTestLength64Int32) {
353+
constexpr int kOutputLen = 64;
354+
int input_shape[] = {1, 66};
355+
const int32_t input[] = {
356+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
357+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
358+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
359+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
360+
int output_shape[] = {1, kOutputLen};
361+
const int32_t golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
362+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
363+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
364+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
365+
int32_t output[kOutputLen];
366+
const TFLMRegistration* registration =
367+
tflite::tflm_signal::Register_IRFFT_INT32();
368+
TF_LITE_MICRO_EXPECT_EQ(
369+
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
370+
input_shape, input, output_shape, golden, *registration,
371+
g_gen_data_fft_length_64_int32,
372+
g_gen_data_size_fft_length_64_int32, output, 0));
373+
}
374+
375+
TF_LITE_MICRO_TEST(IrfftTestLength64Int32OuterDims4) {
376+
constexpr int kOutputLen = 64;
377+
constexpr int kOuterDim = 2;
378+
int input_shape[] = {3, kOuterDim, kOuterDim, 66};
379+
const int32_t input[] = {
380+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
381+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
382+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
383+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
384+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
385+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
386+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
387+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
388+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
389+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
390+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
391+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
392+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
393+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
394+
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
395+
int output_shape[] = {3, kOuterDim, kOuterDim, kOutputLen};
396+
const int32_t golden[] = {
397+
256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
398+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
399+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0,
400+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
401+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
402+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0, 0, 0,
403+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
404+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
405+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0, 0, 0, 0, 0,
406+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
407+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
408+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
409+
int32_t output[kOuterDim * kOuterDim * kOutputLen];
410+
const TFLMRegistration* registration =
411+
tflite::tflm_signal::Register_IRFFT_INT32();
412+
TF_LITE_MICRO_EXPECT_EQ(
413+
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
414+
input_shape, input, output_shape, golden, *registration,
415+
g_gen_data_fft_length_64_int32,
416+
g_gen_data_size_fft_length_64_int32, output, 0));
417+
}
418+
419+
TF_LITE_MICRO_TEST(IrfftTestLength512Float) {
420+
constexpr int kOutputLen = 512;
421+
int input_shape[] = {1, 514};
422+
int output_shape[] = {1, kOutputLen};
423+
float output[kOutputLen];
424+
const TFLMRegistration* registration =
425+
tflite::tflm_signal::Register_IRFFT_FLOAT();
426+
TF_LITE_MICRO_EXPECT_EQ(
427+
kTfLiteOk, tflite::testing::TestFFT<float>(
428+
input_shape, tflite::kIrfftFloatLength512Input,
429+
output_shape, tflite::kIrfftFloatLength512Golden,
430+
*registration, g_gen_data_fft_length_512_float,
431+
g_gen_data_size_fft_length_512_float, output, 1e-7));
432+
}
433+
434+
TF_LITE_MICRO_TEST(IrfftTestLength512Int16) {
435+
constexpr int kOutputLen = 512;
436+
int input_shape[] = {1, 514};
437+
int output_shape[] = {1, kOutputLen};
438+
int16_t output[kOutputLen];
439+
const TFLMRegistration* registration =
440+
tflite::tflm_signal::Register_IRFFT_INT16();
441+
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
442+
tflite::testing::TestFFT<int16_t>(
443+
input_shape, tflite::kIrfftInt16Length512Input,
444+
output_shape, tflite::kIrfftInt16Length512Golden,
445+
*registration, g_gen_data_fft_length_512_int16,
446+
g_gen_data_size_fft_length_512_int16, output, 0));
447+
}
448+
449+
TF_LITE_MICRO_TEST(IrfftTestLength512Int32) {
450+
constexpr int kOutputLen = 512;
451+
int input_shape[] = {1, 514};
452+
int output_shape[] = {1, kOutputLen};
453+
int32_t output[kOutputLen];
454+
const TFLMRegistration* registration =
455+
tflite::tflm_signal::Register_IRFFT_INT32();
456+
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
457+
tflite::testing::TestFFT<int32_t>(
458+
input_shape, tflite::kIrfftInt32Length512Input,
459+
output_shape, tflite::kIrfftInt32Length512Golden,
460+
*registration, g_gen_data_fft_length_512_int32,
461+
g_gen_data_size_fft_length_512_int32, output, 0));
462+
}
463+
306464
TF_LITE_MICRO_TEST(FftAutoScaleTestSmall) {
307465
constexpr int kTensorsSize = 8;
308466
int shape[] = {1, 8};

0 commit comments

Comments
 (0)