Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tflite_micro/python_ops_resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ PythonOpsResolver::PythonOpsResolver() {
AddGreaterEqual();
AddHardSwish();
AddIf();
AddIrfft();
AddL2Normalization();
AddL2Pool2D();
AddLeakyRelu();
Expand Down
2 changes: 2 additions & 0 deletions python/tflite_micro/signal/ops/fft_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def _fft_auto_scale(input_tensor, name=default_name):


rfft = _fft_wrapper(gen_fft_ops.signal_rfft, "signal_rfft")
irfft = _fft_wrapper(gen_fft_ops.signal_irfft, "signal_irfft")
fft_auto_scale = _fft_auto_scale_wrapper(gen_fft_ops.signal_fft_auto_scale,
"signal_fft_auto_scale")
tf.no_gradient("signal_rfft")
tf.no_gradient("signal_irfft")
tf.no_gradient("signal_fft_auto_scale")
55 changes: 55 additions & 0 deletions python/tflite_micro/signal/ops/fft_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,61 @@ def testFftLengthNoEven(self):
with self.assertRaises((tf.errors.InvalidArgumentError, ValueError)):
self.evaluate(fft_ops.rfft(fft_input, 127))

def testIrfftTest(self):
for dtype in [np.int16, np.int32, np.float32]:
fft_length = fft_ops._MIN_FFT_LENGTH
while fft_length <= fft_ops._MAX_FFT_LENGTH:
if dtype == np.float32:
# Random input in the range [-1, 1)
fft_input = np.random.random(fft_length).astype(dtype) * 2 - 1
else:
fft_input = np.random.randint(
np.iinfo(np.int16).min,
np.iinfo(np.int16).max + 1, fft_length).astype(dtype)
fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
self.assertEqual(fft_output.shape[0], (fft_length / 2 + 1) * 2)
ifft_output = self.evaluate(fft_ops.irfft(fft_output, fft_length))
self.assertEqual(ifft_output.shape[0], fft_length)
# Output of integer RFFT and IRFFT is scaled by 1/fft_length
if dtype == np.int16:
self.assertArrayNear(fft_input,
ifft_output.astype(np.int32) * fft_length, 6500)
elif dtype == np.int32:
self.assertArrayNear(fft_input,
ifft_output.astype(np.int32) * fft_length, 7875)
else:
self.assertArrayNear(fft_input, ifft_output, 5e-7)
fft_length = 2 * fft_length

def testIrfftLargeOuterDimension(self):
for dtype in [np.int16, np.int32, np.float32]:
fft_length = fft_ops._MIN_FFT_LENGTH
while fft_length <= fft_ops._MAX_FFT_LENGTH:
if dtype == np.float32:
# Random input in the range [-1, 1)
fft_input = np.random.random([2, 5, fft_length
]).astype(dtype) * 2 - 1
else:
fft_input = np.random.randint(
np.iinfo(np.int16).min,
np.iinfo(np.int16).max + 1, [2, 5, fft_length]).astype(dtype)
fft_output = self.evaluate(fft_ops.rfft(fft_input, fft_length))
self.assertEqual(fft_output.shape[-1], (fft_length / 2 + 1) * 2)
ifft_output = self.evaluate(fft_ops.irfft(fft_output, fft_length))
self.assertEqual(ifft_output.shape[-1], fft_length)
# Output of integer RFFT and IRFFT is scaled by 1/fft_length
if dtype == np.int16:
self.assertAllClose(fft_input,
ifft_output.astype(np.int32) * fft_length,
atol=7875)
elif dtype == np.int32:
self.assertAllClose(fft_input,
ifft_output.astype(np.int32) * fft_length,
atol=7875)
else:
self.assertAllClose(fft_input, ifft_output, rtol=5e-7, atol=5e-7)
fft_length = 2 * fft_length

def testAutoScale(self):
self.SingleFftAutoScaleTest('testdata/fft_auto_scale_test1.txt')

Expand Down
3 changes: 3 additions & 0 deletions signal/micro/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ cc_library(
"filter_bank_spectral_subtraction.cc",
"filter_bank_square_root.cc",
"framer.cc",
"irfft.cc",
"overlap_add.cc",
"rfft.cc",
"stacker.cc",
"window.cc",
],
hdrs = [
"irfft.h",
"rfft.h",
],
copts = micro_copts(),
Expand All @@ -36,6 +38,7 @@ cc_library(
"//signal/src:filter_bank_log",
"//signal/src:filter_bank_spectral_subtraction",
"//signal/src:filter_bank_square_root",
"//signal/src:irfft",
"//signal/src:overlap_add",
"//signal/src:rfft",
"//signal/src:window",
Expand Down
158 changes: 158 additions & 0 deletions signal/micro/kernels/fft_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,164 @@ TF_LITE_MICRO_TEST(RfftTestSize512Int32) {
g_gen_data_size_fft_length_512_int32, output, 0));
}

TF_LITE_MICRO_TEST(IrfftTestLength64Float) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const float input[] = {256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const float golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
float output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_FLOAT();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<float>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_float,
g_gen_data_size_fft_length_64_int16, output, 1e-7));
}

TF_LITE_MICRO_TEST(IrfftTestLength64Int16) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const int16_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const int16_t golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int16_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT16();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int16_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int16,
g_gen_data_size_fft_length_64_int16, output, 0));
}

TF_LITE_MICRO_TEST(IrfftTestLength64Int32) {
constexpr int kOutputLen = 64;
int input_shape[] = {1, 66};
const int32_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {1, kOutputLen};
const int32_t golden[] = {256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int32_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int32,
g_gen_data_size_fft_length_64_int32, output, 0));
}

TF_LITE_MICRO_TEST(IrfftTestLength64Int32OuterDims4) {
constexpr int kOutputLen = 64;
constexpr int kOuterDim = 2;
int input_shape[] = {3, kOuterDim, kOuterDim, 66};
const int32_t input[] = {
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0,
256, 0, 256, 0, 256, 0, 256, 0, 256, 0, 256, 0};
int output_shape[] = {3, kOuterDim, kOuterDim, kOutputLen};
const int32_t golden[] = {
256, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 256, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int32_t output[kOuterDim * kOuterDim * kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<int32_t>(
input_shape, input, output_shape, golden, *registration,
g_gen_data_fft_length_64_int32,
g_gen_data_size_fft_length_64_int32, output, 0));
}

TF_LITE_MICRO_TEST(IrfftTestLength512Float) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
float output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_FLOAT();
TF_LITE_MICRO_EXPECT_EQ(
kTfLiteOk, tflite::testing::TestFFT<float>(
input_shape, tflite::kIrfftFloatLength512Input,
output_shape, tflite::kIrfftFloatLength512Golden,
*registration, g_gen_data_fft_length_512_float,
g_gen_data_size_fft_length_512_float, output, 1e-7));
}

TF_LITE_MICRO_TEST(IrfftTestLength512Int16) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
int16_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT16();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::testing::TestFFT<int16_t>(
input_shape, tflite::kIrfftInt16Length512Input,
output_shape, tflite::kIrfftInt16Length512Golden,
*registration, g_gen_data_fft_length_512_int16,
g_gen_data_size_fft_length_512_int16, output, 0));
}

TF_LITE_MICRO_TEST(IrfftTestLength512Int32) {
constexpr int kOutputLen = 512;
int input_shape[] = {1, 514};
int output_shape[] = {1, kOutputLen};
int32_t output[kOutputLen];
const TFLMRegistration* registration =
tflite::tflm_signal::Register_IRFFT_INT32();
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk,
tflite::testing::TestFFT<int32_t>(
input_shape, tflite::kIrfftInt32Length512Input,
output_shape, tflite::kIrfftInt32Length512Golden,
*registration, g_gen_data_fft_length_512_int32,
g_gen_data_size_fft_length_512_int32, output, 0));
}

TF_LITE_MICRO_TEST(FftAutoScaleTestSmall) {
constexpr int kTensorsSize = 8;
int shape[] = {1, 8};
Expand Down
Loading