From fa5fd01612d454a4c374e6bc4fd7cdf999e1c9c4 Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Mon, 16 Jun 2025 04:52:39 +0000 Subject: [PATCH 01/15] Sync files related to Reverse_V2 from TFLite #3110 --- ci/tflite_files.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/tflite_files.txt b/ci/tflite_files.txt index 83bd58fc0a6..a95475ec9e0 100644 --- a/ci/tflite_files.txt +++ b/ci/tflite_files.txt @@ -88,6 +88,7 @@ tensorflow/lite/kernels/internal/reference/reduce.h tensorflow/lite/kernels/internal/reference/requantize.h tensorflow/lite/kernels/internal/reference/resize_bilinear.h tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h +tensorflow/lite/kernels/internal/reference/reverse.h tensorflow/lite/kernels/internal/reference/round.h tensorflow/lite/kernels/internal/reference/softmax.h tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h From 780c702c7e8a6dfb2bb98e6f8586152766ecbc3a Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Thu, 4 Sep 2025 06:56:31 +0000 Subject: [PATCH 02/15] PRelu Int16x8 support in RefC --- tensorflow/lite/micro/kernels/prelu.cc | 16 +++- tensorflow/lite/micro/kernels/prelu_test.cc | 87 +++++++++++++++------ 2 files changed, 75 insertions(+), 28 deletions(-) diff --git a/tensorflow/lite/micro/kernels/prelu.cc b/tensorflow/lite/micro/kernels/prelu.cc index 66a017b2aec..e98b4832c0c 100644 --- a/tensorflow/lite/micro/kernels/prelu.cc +++ b/tensorflow/lite/micro/kernels/prelu.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -61,9 +61,19 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { tflite::micro::GetTensorData(output)); return kTfLiteOk; } break; + case kTfLiteInt16: { + reference_ops::BroadcastPrelu4DSlow( + params, tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(alpha), + tflite::micro::GetTensorData(alpha), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + return kTfLiteOk; + } break; default: - MicroPrintf("Only float32 and uint8_t are supported currently, got %d.", - TfLiteTypeGetName(input->type)); + MicroPrintf("Input type '%s' (%d) is not supported.", + TfLiteTypeGetName(input->type), input->type); return kTfLiteError; } } diff --git a/tensorflow/lite/micro/kernels/prelu_test.cc b/tensorflow/lite/micro/kernels/prelu_test.cc index e4060347faf..4ab2e321407 100644 --- a/tensorflow/lite/micro/kernels/prelu_test.cc +++ b/tensorflow/lite/micro/kernels/prelu_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,26 +23,21 @@ namespace tflite { namespace testing { namespace { -template -void ValidatePreluGoldens(TfLiteTensor* tensors, int tensors_size, - const T* golden, const int output_length, - T* output_data) { +const float kQuantizedTolerance = 2 * (1. / 256); + +void ExecutePReluTest(const int tensors_count, + TfLiteTensor* tensors) { int inputs_array_data[] = {2, 0, 1}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 2}; TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); const TFLMRegistration registration = tflite::Register_PRELU(); - micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array, - outputs_array, - /*builtin_data=*/nullptr); + micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array, + outputs_array, /*builtin_data=*/nullptr); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); - - for (int i = 0; i < output_length; ++i) { - TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], 1e-5f); - } } void TestPreluFloat(int* input_dims_data, const float* input_data, @@ -62,19 +57,23 @@ void TestPreluFloat(int* input_dims_data, const float* input_data, CreateTensor(output_data, output_dims), }; - ValidatePreluGoldens(tensors, tensors_size, expected_output_data, - output_dims_count, output_data); + ExecutePReluTest(tensors_size, tensors); + + for (int i = 0; i < output_dims_count; i++) { + TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]); + } } -template +template void TestPreluQuantized(int* input_dims_data, const float* input_data, T* input_quantized, const float input_scale, const int input_zero_point, int* alpha_dims_data, - const float* alpha_data, T* alpha_quantized, + const float* alpha_data, Slope* alpha_quantized, const float alpha_scale, const int alpha_zero_point, const float* golden, T* golden_quantized, const float output_scale, const int output_zero_point, - int* output_dims_data, T* output_data) { + int* output_dims_data, T* output_quantized, + float* output_data) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* alpha_dims = IntArrayFromInts(alpha_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); @@ -87,15 +86,18 @@ void TestPreluQuantized(int* input_dims_data, const float* input_data, input_scale, input_zero_point), CreateQuantizedTensor(alpha_data, alpha_quantized, alpha_dims, alpha_scale, alpha_zero_point), - CreateQuantizedTensor(output_data, output_dims, output_scale, + CreateQuantizedTensor(output_quantized, output_dims, output_scale, output_zero_point), }; - Quantize(golden, golden_quantized, output_dims_count, output_scale, - output_zero_point); + ExecutePReluTest(tensors_size, tensors); - ValidatePreluGoldens(tensors, tensors_size, golden_quantized, - output_dims_count, output_data); + Dequantize(output_quantized, output_dims_count, output_scale, output_zero_point, + output_data); + + for (int i = 0; i < output_dims_count; i++) { + TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], kQuantizedTolerance); + } } } // namespace } // namespace testing @@ -150,10 +152,45 @@ TF_LITE_MICRO_TEST(QuantizedInt8PreluActivationsOpTest) { int8_t golden_quantized[dims_count]; float scale = 2.0 / 255.0; int zero_point = 0; - int8_t output_data[dims_count]; - tflite::testing::TestPreluQuantized( + int8_t output_data_q[dims_count]; + float output_data_f[dims_count]; + tflite::testing::TestPreluQuantized( input_shape, input_values, input_quantized, scale, zero_point, alpha_shape, alpha_values, alpha_quantized, scale, zero_point, golden, - golden_quantized, scale, zero_point, output_shape, output_data); + golden_quantized, scale, zero_point, output_shape, output_data_q, + output_data_f); +} + +TF_LITE_MICRO_TEST(QuantizedInt16PreluActivationsOpTest) { + int input_shape[] = {3, 2, 2, 3}; + const float input_values[] = { + 0.0f, 0.0f, 0.0f, // Row 1, Column 1 + 0.5f, 0.5f, 0.5f, // Row 1, Column 2 + -1.0f, -1.0f, -1.0f, // Row 2, Column 1 + -0.25f, -0.25f, -0.25f, // Row 1, Column 2 + }; + int alpha_shape[] = {3, 1, 1, 3}; + const float alpha_values[] = {0.0f, 0.5f, -0.5f}; + int output_shape[] = {3, 2, 2, 3}; + const float golden[] = { + 0.0f, 0.0f, 0.0f, // Row 1, Column 1 + 0.5f, 0.5f, 0.5f, // Row 1, Column 2 + 0.0f, -0.5f, 0.5f, // Row 2, Column 1 + 0.0f, -0.125f, 0.125f, // Row 1, Column 2 + }; + const int dims_count = 12; + int16_t input_quantized[dims_count]; + int8_t alpha_quantized[3]; + int16_t golden_quantized[dims_count]; + float scale_input_output = 2.0 / 65535.0; + float scale_alpha = 2.0 / 255.0; + int zero_point = 0; + int16_t output_data_q[dims_count]; + float output_data_f[dims_count]; + tflite::testing::TestPreluQuantized( + input_shape, input_values, input_quantized, scale_input_output, zero_point, + alpha_shape, alpha_values, alpha_quantized, scale_alpha, zero_point, golden, + golden_quantized, scale_input_output, zero_point, output_shape, + output_data_q, output_data_f); } TF_LITE_MICRO_TESTS_END From b9ea8b1f9389aadeaf2ed106d869ac24588c201e Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Thu, 4 Sep 2025 07:17:29 +0000 Subject: [PATCH 03/15] Fix code style in prelu_test.cc --- tensorflow/lite/micro/kernels/prelu_test.cc | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/micro/kernels/prelu_test.cc b/tensorflow/lite/micro/kernels/prelu_test.cc index 4ab2e321407..31648d13f67 100644 --- a/tensorflow/lite/micro/kernels/prelu_test.cc +++ b/tensorflow/lite/micro/kernels/prelu_test.cc @@ -25,8 +25,7 @@ namespace { const float kQuantizedTolerance = 2 * (1. / 256); -void ExecutePReluTest(const int tensors_count, - TfLiteTensor* tensors) { +void ExecutePReluTest(const int tensors_count, TfLiteTensor* tensors) { int inputs_array_data[] = {2, 0, 1}; TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); int outputs_array_data[] = {1, 2}; @@ -92,9 +91,9 @@ void TestPreluQuantized(int* input_dims_data, const float* input_data, ExecutePReluTest(tensors_size, tensors); - Dequantize(output_quantized, output_dims_count, output_scale, output_zero_point, - output_data); - + Dequantize(output_quantized, output_dims_count, output_scale, + output_zero_point, output_data); + for (int i = 0; i < output_dims_count; i++) { TF_LITE_MICRO_EXPECT_NEAR(golden[i], output_data[i], kQuantizedTolerance); } @@ -188,9 +187,9 @@ TF_LITE_MICRO_TEST(QuantizedInt16PreluActivationsOpTest) { int16_t output_data_q[dims_count]; float output_data_f[dims_count]; tflite::testing::TestPreluQuantized( - input_shape, input_values, input_quantized, scale_input_output, zero_point, - alpha_shape, alpha_values, alpha_quantized, scale_alpha, zero_point, golden, - golden_quantized, scale_input_output, zero_point, output_shape, - output_data_q, output_data_f); + input_shape, input_values, input_quantized, scale_input_output, + zero_point, alpha_shape, alpha_values, alpha_quantized, scale_alpha, + zero_point, golden, golden_quantized, scale_input_output, zero_point, + output_shape, output_data_q, output_data_f); } TF_LITE_MICRO_TESTS_END From 04faaa850732985154fd14629f403d8d463c1413 Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Mon, 8 Sep 2025 04:57:00 +0000 Subject: [PATCH 04/15] 1. Reverted the copyright year --- tensorflow/lite/micro/kernels/prelu.cc | 6 +++--- tensorflow/lite/micro/kernels/prelu_common.cc | 5 +++++ tensorflow/lite/micro/kernels/prelu_test.cc | 14 ++++++-------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/micro/kernels/prelu.cc b/tensorflow/lite/micro/kernels/prelu.cc index e98b4832c0c..cc7a900c0de 100644 --- a/tensorflow/lite/micro/kernels/prelu.cc +++ b/tensorflow/lite/micro/kernels/prelu.cc @@ -1,4 +1,4 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -72,8 +72,8 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } break; default: - MicroPrintf("Input type '%s' (%d) is not supported.", - TfLiteTypeGetName(input->type), input->type); + MicroPrintf("Input type '%s' is not supported.", + TfLiteTypeGetName(input->type)); return kTfLiteError; } } diff --git a/tensorflow/lite/micro/kernels/prelu_common.cc b/tensorflow/lite/micro/kernels/prelu_common.cc index 1a89cadf9d1..343c8e4e916 100644 --- a/tensorflow/lite/micro/kernels/prelu_common.cc +++ b/tensorflow/lite/micro/kernels/prelu_common.cc @@ -96,6 +96,11 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, CalculatePreluParams(input, alpha, output, params)); + if (output->type == kTfLiteInt16) { + // Make sure alpha type is Int8 when Output is Int16 + TF_LITE_ENSURE(context, alpha->type == kTfLiteInt8); + } + micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(alpha); micro_context->DeallocateTempTfLiteTensor(output); diff --git a/tensorflow/lite/micro/kernels/prelu_test.cc b/tensorflow/lite/micro/kernels/prelu_test.cc index 31648d13f67..82b9c284d79 100644 --- a/tensorflow/lite/micro/kernels/prelu_test.cc +++ b/tensorflow/lite/micro/kernels/prelu_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -69,10 +69,9 @@ void TestPreluQuantized(int* input_dims_data, const float* input_data, const int input_zero_point, int* alpha_dims_data, const float* alpha_data, Slope* alpha_quantized, const float alpha_scale, const int alpha_zero_point, - const float* golden, T* golden_quantized, - const float output_scale, const int output_zero_point, - int* output_dims_data, T* output_quantized, - float* output_data) { + const float* golden, const float output_scale, + const int output_zero_point, int* output_dims_data, + T* output_quantized, float* output_data) { TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data); TfLiteIntArray* alpha_dims = IntArrayFromInts(alpha_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); @@ -180,7 +179,6 @@ TF_LITE_MICRO_TEST(QuantizedInt16PreluActivationsOpTest) { const int dims_count = 12; int16_t input_quantized[dims_count]; int8_t alpha_quantized[3]; - int16_t golden_quantized[dims_count]; float scale_input_output = 2.0 / 65535.0; float scale_alpha = 2.0 / 255.0; int zero_point = 0; @@ -189,7 +187,7 @@ TF_LITE_MICRO_TEST(QuantizedInt16PreluActivationsOpTest) { tflite::testing::TestPreluQuantized( input_shape, input_values, input_quantized, scale_input_output, zero_point, alpha_shape, alpha_values, alpha_quantized, scale_alpha, - zero_point, golden, golden_quantized, scale_input_output, zero_point, - output_shape, output_data_q, output_data_f); + zero_point, golden, scale_input_output, zero_point, output_shape, + output_data_q, output_data_f); } TF_LITE_MICRO_TESTS_END From f27ec75fa42e31bb76e632307aca324834f1181a Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Mon, 8 Sep 2025 05:18:57 +0000 Subject: [PATCH 05/15] Resolved compilation error for Int8x8 test case --- tensorflow/lite/micro/kernels/prelu_test.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/kernels/prelu_test.cc b/tensorflow/lite/micro/kernels/prelu_test.cc index 82b9c284d79..7c2babf624e 100644 --- a/tensorflow/lite/micro/kernels/prelu_test.cc +++ b/tensorflow/lite/micro/kernels/prelu_test.cc @@ -147,7 +147,6 @@ TF_LITE_MICRO_TEST(QuantizedInt8PreluActivationsOpTest) { const int dims_count = 12; int8_t input_quantized[dims_count]; int8_t alpha_quantized[3]; - int8_t golden_quantized[dims_count]; float scale = 2.0 / 255.0; int zero_point = 0; int8_t output_data_q[dims_count]; @@ -155,8 +154,7 @@ TF_LITE_MICRO_TEST(QuantizedInt8PreluActivationsOpTest) { tflite::testing::TestPreluQuantized( input_shape, input_values, input_quantized, scale, zero_point, alpha_shape, alpha_values, alpha_quantized, scale, zero_point, golden, - golden_quantized, scale, zero_point, output_shape, output_data_q, - output_data_f); + scale, zero_point, output_shape, output_data_q, output_data_f); } TF_LITE_MICRO_TEST(QuantizedInt16PreluActivationsOpTest) { From 29e40f6014e18f40caab82b5be603f6c099fd560 Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 28 Oct 2025 10:33:56 +0000 Subject: [PATCH 06/15] Add Dynamic_Update_Slice support to TFLM --- python/tflite_micro/python_ops_resolver.cc | 1 + tensorflow/lite/micro/kernels/BUILD | 17 ++ .../micro/kernels/dynamic_update_slice.cc | 240 ++++++++++++++++++ .../lite/micro/kernels/dynamic_update_slice.h | 38 +++ .../kernels/dynamic_update_slice_test.cc | 146 +++++++++++ tensorflow/lite/micro/kernels/micro_ops.h | 1 + .../lite/micro/micro_mutable_op_resolver.h | 6 + .../micro/tools/benchmarking/op_resolver.h | 3 +- tensorflow/lite/micro/tools/make/Makefile | 1 + 9 files changed, 452 insertions(+), 1 deletion(-) create mode 100644 tensorflow/lite/micro/kernels/dynamic_update_slice.cc create mode 100644 tensorflow/lite/micro/kernels/dynamic_update_slice.h create mode 100644 tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc diff --git a/python/tflite_micro/python_ops_resolver.cc b/python/tflite_micro/python_ops_resolver.cc index 77ef336d9de..2b8f1a4a0ae 100644 --- a/python/tflite_micro/python_ops_resolver.cc +++ b/python/tflite_micro/python_ops_resolver.cc @@ -46,6 +46,7 @@ PythonOpsResolver::PythonOpsResolver() { AddDequantize(); AddDetectionPostprocess(); AddDiv(); + AddDynamicUpdateSlice(); AddElu(); AddEmbeddingLookup(); AddEnergy(); diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index f1d12f04634..9c1fa46e842 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -243,6 +243,7 @@ tflm_kernel_cc_library( "dequantize_common.cc", "detection_postprocess.cc", "div.cc", + "dynamic_update_slice.cc", "elementwise.cc", "elu.cc", "embedding_lookup.cc", @@ -329,6 +330,7 @@ tflm_kernel_cc_library( "conv.h", "depthwise_conv.h", "dequantize.h", + "dynamic_update_slice.h", "ethosu.h", "fully_connected.h", "hard_swish.h", @@ -737,6 +739,21 @@ tflm_cc_test( ], ) +tflm_cc_test( + name = "dynamic_update_slice_test", + srcs = [ + "dynamic_update_slice_test.cc", + ], + deps = [ + ":kernel_runner", + "//tensorflow/lite/c:common", + "//tensorflow/lite/micro:debug_log", + "//tensorflow/lite/micro:op_resolvers", + "//tensorflow/lite/micro:test_helpers", + "//tensorflow/lite/micro/testing:micro_test", + ], +) + tflm_cc_test( name = "elementwise_test", srcs = ["elementwise_test.cc"], diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc new file mode 100644 index 00000000000..768467b3c45 --- /dev/null +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc @@ -0,0 +1,240 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/micro/kernels/dynamic_update_slice.h" + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" +#include "tensorflow/lite/micro/kernels/kernel_util.h" +#include "tensorflow/lite/micro/micro_log.h" +#include "tensorflow/lite/micro/micro_utils.h" + +namespace tflite { + +constexpr int kMaxDimensions = 6; + +namespace { + +TfLiteStatus CalculateClampedStartIndices( + int num_dims, const int64_t* raw_indices_data, + const int32_t* input_dims_data, const int32_t* update_dims_data, + int32_t* clamped_start_indices_output) { + for (int i = 0; i < num_dims; ++i) { + clamped_start_indices_output[i] = static_cast( + std::min(std::max(0, raw_indices_data[i]), + input_dims_data[i] - update_dims_data[i])); + } + return kTfLiteOk; +} + +// Recursive helper for N-dimensional slice update. +template +TfLiteStatus UpdateSliceRecursive( + int current_dim, int max_dims, const int32_t* output_strides, + const int32_t* update_strides, const int32_t* update_dims_data, + const T* update_tensor_data, const int32_t* clamped_start_indices, + T* output_tensor_data) { + if (current_dim == max_dims) { + return kTfLiteOk; + } + + output_tensor_data += + clamped_start_indices[current_dim] * output_strides[current_dim]; + + if (current_dim == max_dims - 1) { + std::memcpy(output_tensor_data, update_tensor_data, + update_dims_data[max_dims - 1] * sizeof(T)); + } else { + for (int i = 0; i < update_dims_data[current_dim]; ++i) { + UpdateSliceRecursive( + current_dim + 1, max_dims, output_strides, update_strides, + update_dims_data, update_tensor_data, clamped_start_indices, + output_tensor_data); + + output_tensor_data += output_strides[current_dim]; + update_tensor_data += update_strides[current_dim]; + } + } + return kTfLiteOk; +} + +// Main dispatch function for Eval, templated on data type. +template +TfLiteStatus EvalImpl(const TfLiteEvalTensor* operand_eval, + const TfLiteEvalTensor* update_eval, + const int64_t* indices_eval, + TfLiteEvalTensor* output_eval) { + const RuntimeShape operand_shape = + tflite::micro::GetTensorShape(operand_eval); + const RuntimeShape update_shape = tflite::micro::GetTensorShape(update_eval); + const T* update_tensor_data = tflite::micro::GetTensorData(update_eval); + T* output_tensor_data = tflite::micro::GetTensorData(output_eval); + + const int num_dims = operand_shape.DimensionsCount(); + if (operand_shape.FlatSize() == update_shape.FlatSize()) { + std::memcpy(output_tensor_data, update_tensor_data, + ElementCount(*operand_eval->dims) * sizeof(T)); + return kTfLiteOk; + } + + if (num_dims > kMaxDimensions) { + MicroPrintf( + "DYNAMIC_UPDATE_SLICE: Operand rank %d exceeds max supported %d.", + num_dims, kMaxDimensions); + return kTfLiteError; + } + + if (operand_eval->data.data != output_eval->data.data) { + std::memcpy(output_eval->data.data, operand_eval->data.data, + ElementCount(*operand_eval->dims) * sizeof(T)); + } + + // If update tensor is empty, no actual update is needed after operand copy. + if (ElementCount(*update_eval->dims) == 0) { + return kTfLiteOk; + } + + // Calculate clamped start indices (stack-allocated) + int32_t clamped_start_indices[kMaxDimensions]; + TF_LITE_ENSURE_STATUS(CalculateClampedStartIndices( + num_dims, indices_eval, operand_shape.DimsData(), + update_shape.DimsData(), clamped_start_indices)); + + // Calculate strides (stack-allocated) + int32_t output_stride[kMaxDimensions]; + int32_t update_stride[kMaxDimensions]; + output_stride[num_dims - 1] = 1; + update_stride[num_dims - 1] = 1; + for (int i = num_dims - 2; i >= 0; --i) { + output_stride[i] = output_stride[i + 1] * operand_shape.Dims(i + 1); + update_stride[i] = update_stride[i + 1] * update_shape.Dims(i + 1); + } + + // Perform the N-dimensional update + // The recursive function needs base pointers and initial offsets. + return UpdateSliceRecursive( + /*current_dim=*/0, num_dims, output_stride, update_stride, + update_shape.DimsData(), update_tensor_data, clamped_start_indices, + output_tensor_data); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + MicroContext* micro_context = GetMicroContext(context); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + // Use MicroContext to allocate temporary tensors for inspection + // This is a robust pattern shown in EMBEDDING_LOOKUP. + TfLiteTensor* operand = micro_context->AllocateTempInputTensor( + node, kDynamicUpdateSliceOperandTensor); + TF_LITE_ENSURE(context, operand != nullptr); + + TfLiteTensor* update = micro_context->AllocateTempInputTensor( + node, kDynamicUpdateSliceUpdateTensor); + TF_LITE_ENSURE(context, update != nullptr); + + TfLiteTensor* start_indices = micro_context->AllocateTempInputTensor( + node, kDynamicUpdateSliceStartIndicesTensor); + TF_LITE_ENSURE(context, start_indices != nullptr); + + TfLiteTensor* output = micro_context->AllocateTempOutputTensor( + node, kDynamicUpdateSliceOutputTensor); + TF_LITE_ENSURE(context, output != nullptr); + + // Type checks + TF_LITE_ENSURE_TYPES_EQ(context, operand->type, update->type); + TF_LITE_ENSURE(context, start_indices->type == kTfLiteInt32 || + start_indices->type == kTfLiteInt64); + + TF_LITE_ENSURE_EQ(context, NumDimensions(start_indices), 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(start_indices, 0), + NumDimensions(operand)); + + TF_LITE_ENSURE_EQ(context, NumDimensions(update), NumDimensions(operand)); + // Check that update dimensions are not larger than operand dimensions + for (int i = 0; i < NumDimensions(operand); ++i) { + TF_LITE_ENSURE(context, + SizeOfDimension(update, i) <= SizeOfDimension(operand, i)); + } + output->type = operand->type; + + // Deallocate temporary tensors + micro_context->DeallocateTempTfLiteTensor(operand); + micro_context->DeallocateTempTfLiteTensor(update); + micro_context->DeallocateTempTfLiteTensor(start_indices); + micro_context->DeallocateTempTfLiteTensor( + output); // Output tensor metadata also temp + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteEvalTensor* operand_eval = tflite::micro::GetEvalInput( + context, node, kDynamicUpdateSliceOperandTensor); + const TfLiteEvalTensor* update_eval = tflite::micro::GetEvalInput( + context, node, kDynamicUpdateSliceUpdateTensor); + const TfLiteEvalTensor* indices_eval = tflite::micro::GetEvalInput( + context, node, kDynamicUpdateSliceStartIndicesTensor); + TfLiteEvalTensor* output_eval = tflite::micro::GetEvalOutput( + context, node, kDynamicUpdateSliceOutputTensor); + + const auto& input_shape = tflite::micro::GetTensorShape(operand_eval); + const int input_dims = input_shape.DimensionsCount(); + int64_t indices_data_i64[kMaxDimensions]; + if (indices_eval->type == kTfLiteInt32) { + for (int i = 0; i < input_dims; i++) + indices_data_i64[i] = static_cast(indices_eval->data.i32[i]); + } else if (indices_eval->type == kTfLiteInt64) { + for (int i = 0; i < input_dims; i++) + indices_data_i64[i] = indices_eval->data.i64[i]; + } else { + TF_LITE_KERNEL_LOG(context, + "DynamicUpdateSlice only currently supports " + "int32 or int64 indices type, got %d.", + indices_eval->type); + return kTfLiteError; + } + // Dispatch based on tensor type + switch (operand_eval->type) { + case kTfLiteFloat32: + return EvalImpl(operand_eval, update_eval, + indices_data_i64, output_eval); + case kTfLiteInt8: + return EvalImpl(operand_eval, update_eval, + indices_data_i64, output_eval); + case kTfLiteInt32: + return EvalImpl(operand_eval, update_eval, + indices_data_i64, output_eval); + default: + MicroPrintf("DYNAMIC_UPDATE_SLICE: Operand type %s not supported.", + TfLiteTypeGetName(operand_eval->type)); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace + +TFLMRegistration Register_DYNAMIC_UPDATE_SLICE() { + return tflite::micro::RegisterOp(/*init=*/nullptr, /*prepare=*/Prepare, + /*invoke=*/Eval); +} + +} // namespace tflite + diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice.h b/tensorflow/lite/micro/kernels/dynamic_update_slice.h new file mode 100644 index 00000000000..e78ac529da2 --- /dev/null +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice.h @@ -0,0 +1,38 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_MICRO_KERNELS_DYNAMIC_UPDATE_SLICE_H_ +#define TENSORFLOW_LITE_MICRO_KERNELS_DYNAMIC_UPDATE_SLICE_H_ + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/micro/micro_common.h" + +namespace tflite{ + +constexpr int kDynamicUpdateSliceOperandTensor = 0; +constexpr int kDynamicUpdateSliceUpdateTensor = 1; +constexpr int kDynamicUpdateSliceStartIndicesTensor = 2; +constexpr int kDynamicUpdateSliceOutputTensor = 0; + +TfLiteStatus PrepareDynamicUpdateSlice(TfLiteContext* context, + TfLiteNode* node); + +TFLMRegistration Register_DYNAMIC_UPDATE_SLICE(); + +} // namespace tflite + + +#endif // TENSORFLOW_LITE_MICRO_KERNELS_DYNAMIC_UPDATE_SLICE_H_ + diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc new file mode 100644 index 00000000000..3971f835484 --- /dev/null +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc @@ -0,0 +1,146 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/micro/kernels/kernel_runner.h" +#include "tensorflow/lite/micro/micro_utils.h" +#include "tensorflow/lite/micro/test_helpers.h" +#include "tensorflow/lite/micro/testing/micro_test.h" + +namespace tflite { +namespace testing { +namespace { + +// constexpr float kTestTolerance = 7.41e-03; +constexpr int kNumInputs = 3; +constexpr int kNumOutputs = 1; +constexpr int kInputTensorIndex_0 = 0; +constexpr int kInputTensorIndex_1 = 1; +constexpr int kInputTensorIndex_2 = 2; +constexpr int kOutputTensorIndex = 3; + +// min/max are used to compute scale, zero-point is 0 +template +struct TestDynamicUpdateSliceParams { + // quantization parameters + float data_min; // input data minimum value + float data_max; // input data maximum value + int8_t input1_data[kInputSize]; // quantized input storage + int8_t input2_data[kInputSize]; // quantized input storage +}; + +void ExecuteDynamicUpdateSliceTest(TfLiteTensor* tensors, int tensors_count) { + int kInputArrayData[] = {kNumInputs, kInputTensorIndex_0, kInputTensorIndex_1, + kInputTensorIndex_2}; + TfLiteIntArray* inputs_array = IntArrayFromInts(kInputArrayData); + int kOutputArrayData[] = {kNumOutputs, kOutputTensorIndex}; + TfLiteIntArray* outputs_array = IntArrayFromInts(kOutputArrayData); + + const TFLMRegistration registration = tflite::Register_DYNAMIC_UPDATE_SLICE(); + micro::KernelRunner runner(registration, tensors, tensors_count, inputs_array, + outputs_array, nullptr); + + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare()); + TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); +} + +void TestDynamicUpdateSliceFloat(int* input_dims_data[kNumInputs], + const float* input_data_0, + const float* input_data_1, + const int32_t* input_data_2, + const float* golden_data, int* expected_dims, + float* output_data) { + TfLiteIntArray* input_dims_0 = IntArrayFromInts(input_dims_data[0]); + TfLiteIntArray* input_dims_1 = IntArrayFromInts(input_dims_data[1]); + TfLiteIntArray* input_dims_2 = IntArrayFromInts(input_dims_data[2]); + TfLiteIntArray* output_dims = IntArrayFromInts(expected_dims); + const int output_count = ElementCount(*output_dims); + + TfLiteTensor tensors[] = { + CreateTensor(input_data_0, input_dims_0), + CreateTensor(input_data_1, input_dims_1), + CreateTensor(input_data_2, input_dims_2), + CreateTensor(output_data, output_dims), + }; + constexpr int tensors_count = std::extent::value; + ExecuteDynamicUpdateSliceTest(tensors, tensors_count); + + // check output data against expected + for (int i = 0; i < output_count; i++) { + printf("output_data[%d] = %f\n", i, output_data[i]); + TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_data[i], 0.0); + } + + // check output dimensions (relocated) against original dimensions + TF_LITE_MICRO_EXPECT_EQ(output_dims->size, + tensors[kOutputTensorIndex].dims->size); + for (int i = 0; i < output_dims->size; i++) { + TF_LITE_MICRO_EXPECT_EQ(output_dims->data[i], + tensors[kOutputTensorIndex].dims->data[i]); + } +} + +// TODO(rameshkunasi): Add quantized test for dynamic update slice. + +} // namespace +} // namespace testing +} // namespace tflite + +TF_LITE_MICRO_TESTS_BEGIN + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleFloat) { + int32_t kInputDims_0[] = {2, 3, 3}; + int32_t kInputDims_1[] = {2, 2, 1}; + int32_t kInputDims_2[] = {1, 2}; + int32_t* kInputDims[tflite::testing::kNumInputs] = { + kInputDims_0, kInputDims_1, kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr float kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr float kInput_1[] = {-1, -2}; + constexpr int32_t kInput_2[] = {1, 1}; + constexpr float kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSliceFloat(kInputDims, kInput_0, kInput_1, + kInput_2, kExpect, kOutputDims, + output_data); +} + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestBoundaryTest) { + int32_t kInputDims_0[] = {2, 3, 3}; + int32_t kInputDims_1[] = {2, 2, 2}; + int32_t kInputDims_2[] = {1, 2}; + int32_t* kInputDims[tflite::testing::kNumInputs] = { + kInputDims_0, kInputDims_1, kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr float kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr float kInput_1[] = {-1, -2, -3, -4}; + constexpr int32_t kInput_2[] = {2, 2}; + constexpr float kExpect[] = {1, 2, 3, 4, -1, -2, 7, -3, -4}; + constexpr int kOutputCount = std::extent::value; + float output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSliceFloat(kInputDims, kInput_0, kInput_1, + kInput_2, kExpect, kOutputDims, + output_data); +} +TF_LITE_MICRO_TESTS_END + diff --git a/tensorflow/lite/micro/kernels/micro_ops.h b/tensorflow/lite/micro/kernels/micro_ops.h index b3c9204b4d8..0762cac1891 100644 --- a/tensorflow/lite/micro/kernels/micro_ops.h +++ b/tensorflow/lite/micro/kernels/micro_ops.h @@ -57,6 +57,7 @@ TFLMRegistration Register_DEPTH_TO_SPACE(); TFLMRegistration Register_DEPTHWISE_CONV_2D(); TFLMRegistration Register_DEQUANTIZE(); TFLMRegistration Register_DIV(); +TFLMRegistration Register_DYNAMIC_UPDATE_SLICE(); TFLMRegistration Register_ELU(); TFLMRegistration Register_EMBEDDING_LOOKUP(); TFLMRegistration Register_EQUAL(); diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index ba94ac19482..4cd3460beab 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -246,6 +246,12 @@ class MicroMutableOpResolver : public MicroOpResolver { return AddBuiltin(BuiltinOperator_DIV, registration, ParseDiv); } + TfLiteStatus AddDynamicUpdateSlice() { + return AddBuiltin(BuiltinOperator_DYNAMIC_UPDATE_SLICE, + Register_DYNAMIC_UPDATE_SLICE(), + ParseDynamicUpdateSlice); + } + TfLiteStatus AddEmbeddingLookup( const TFLMRegistration& registration = Register_EMBEDDING_LOOKUP()) { return AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, registration, diff --git a/tensorflow/lite/micro/tools/benchmarking/op_resolver.h b/tensorflow/lite/micro/tools/benchmarking/op_resolver.h index 7817eaed0e5..4ba845f81bf 100644 --- a/tensorflow/lite/micro/tools/benchmarking/op_resolver.h +++ b/tensorflow/lite/micro/tools/benchmarking/op_resolver.h @@ -23,7 +23,7 @@ limitations under the License. namespace tflite { -using TflmOpResolver = MicroMutableOpResolver<115>; +using TflmOpResolver = MicroMutableOpResolver<116>; inline TfLiteStatus CreateOpResolver(TflmOpResolver& op_resolver) { TF_LITE_ENSURE_STATUS(op_resolver.AddAbs()); @@ -51,6 +51,7 @@ inline TfLiteStatus CreateOpResolver(TflmOpResolver& op_resolver) { TF_LITE_ENSURE_STATUS(op_resolver.AddDequantize()); TF_LITE_ENSURE_STATUS(op_resolver.AddDetectionPostprocess()); TF_LITE_ENSURE_STATUS(op_resolver.AddDiv()); + TF_LITE_ENSURE_STATUS(op_resolver.AddDynamicUpdateSlice()); TF_LITE_ENSURE_STATUS(op_resolver.AddElu()); TF_LITE_ENSURE_STATUS(op_resolver.AddEmbeddingLookup()); TF_LITE_ENSURE_STATUS(op_resolver.AddEnergy()); diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile index a21765b3454..bb9709f0262 100644 --- a/tensorflow/lite/micro/tools/make/Makefile +++ b/tensorflow/lite/micro/tools/make/Makefile @@ -395,6 +395,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_common.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/detection_postprocess.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/div.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dynamic_update_slice.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elementwise.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elu.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/embedding_lookup.cc \ From ae96a48f129a4e2b8c185f8e23ceb6615ab9be9d Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 28 Oct 2025 12:30:54 +0000 Subject: [PATCH 07/15] Code style error correction --- .../micro/kernels/dynamic_update_slice.cc | 44 ++++++++++--------- .../lite/micro/kernels/dynamic_update_slice.h | 2 +- .../kernels/dynamic_update_slice_test.cc | 4 +- .../lite/micro/micro_mutable_op_resolver.h | 3 +- 4 files changed, 27 insertions(+), 26 deletions(-) diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc index 768467b3c45..0d8224497d4 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc @@ -38,18 +38,20 @@ TfLiteStatus CalculateClampedStartIndices( for (int i = 0; i < num_dims; ++i) { clamped_start_indices_output[i] = static_cast( std::min(std::max(0, raw_indices_data[i]), - input_dims_data[i] - update_dims_data[i])); + input_dims_data[i] - update_dims_data[i])); } return kTfLiteOk; } // Recursive helper for N-dimensional slice update. template -TfLiteStatus UpdateSliceRecursive( - int current_dim, int max_dims, const int32_t* output_strides, - const int32_t* update_strides, const int32_t* update_dims_data, - const T* update_tensor_data, const int32_t* clamped_start_indices, - T* output_tensor_data) { +TfLiteStatus UpdateSliceRecursive(int current_dim, int max_dims, + const int32_t* output_strides, + const int32_t* update_strides, + const int32_t* update_dims_data, + const T* update_tensor_data, + const int32_t* clamped_start_indices, + T* output_tensor_data) { if (current_dim == max_dims) { return kTfLiteOk; } @@ -62,10 +64,10 @@ TfLiteStatus UpdateSliceRecursive( update_dims_data[max_dims - 1] * sizeof(T)); } else { for (int i = 0; i < update_dims_data[current_dim]; ++i) { - UpdateSliceRecursive( - current_dim + 1, max_dims, output_strides, update_strides, - update_dims_data, update_tensor_data, clamped_start_indices, - output_tensor_data); + UpdateSliceRecursive(current_dim + 1, max_dims, output_strides, + update_strides, update_dims_data, + update_tensor_data, clamped_start_indices, + output_tensor_data); output_tensor_data += output_strides[current_dim]; update_tensor_data += update_strides[current_dim]; @@ -113,8 +115,8 @@ TfLiteStatus EvalImpl(const TfLiteEvalTensor* operand_eval, // Calculate clamped start indices (stack-allocated) int32_t clamped_start_indices[kMaxDimensions]; TF_LITE_ENSURE_STATUS(CalculateClampedStartIndices( - num_dims, indices_eval, operand_shape.DimsData(), - update_shape.DimsData(), clamped_start_indices)); + num_dims, indices_eval, operand_shape.DimsData(), update_shape.DimsData(), + clamped_start_indices)); // Calculate strides (stack-allocated) int32_t output_stride[kMaxDimensions]; @@ -205,22 +207,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { indices_data_i64[i] = indices_eval->data.i64[i]; } else { TF_LITE_KERNEL_LOG(context, - "DynamicUpdateSlice only currently supports " - "int32 or int64 indices type, got %d.", - indices_eval->type); + "DynamicUpdateSlice only currently supports " + "int32 or int64 indices type, got %d.", + indices_eval->type); return kTfLiteError; } // Dispatch based on tensor type switch (operand_eval->type) { case kTfLiteFloat32: - return EvalImpl(operand_eval, update_eval, - indices_data_i64, output_eval); + return EvalImpl(operand_eval, update_eval, indices_data_i64, + output_eval); case kTfLiteInt8: - return EvalImpl(operand_eval, update_eval, - indices_data_i64, output_eval); + return EvalImpl(operand_eval, update_eval, indices_data_i64, + output_eval); case kTfLiteInt32: - return EvalImpl(operand_eval, update_eval, - indices_data_i64, output_eval); + return EvalImpl(operand_eval, update_eval, indices_data_i64, + output_eval); default: MicroPrintf("DYNAMIC_UPDATE_SLICE: Operand type %s not supported.", TfLiteTypeGetName(operand_eval->type)); diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice.h b/tensorflow/lite/micro/kernels/dynamic_update_slice.h index e78ac529da2..3f68e97aa93 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice.h +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice.h @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/micro/micro_common.h" -namespace tflite{ +namespace tflite { constexpr int kDynamicUpdateSliceOperandTensor = 0; constexpr int kDynamicUpdateSliceUpdateTensor = 1; diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc index 3971f835484..631d86951b0 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc @@ -38,8 +38,8 @@ constexpr int kOutputTensorIndex = 3; template struct TestDynamicUpdateSliceParams { // quantization parameters - float data_min; // input data minimum value - float data_max; // input data maximum value + float data_min; // input data minimum value + float data_max; // input data maximum value int8_t input1_data[kInputSize]; // quantized input storage int8_t input2_data[kInputSize]; // quantized input storage }; diff --git a/tensorflow/lite/micro/micro_mutable_op_resolver.h b/tensorflow/lite/micro/micro_mutable_op_resolver.h index 9665d7d7396..c5540ea669a 100644 --- a/tensorflow/lite/micro/micro_mutable_op_resolver.h +++ b/tensorflow/lite/micro/micro_mutable_op_resolver.h @@ -253,8 +253,7 @@ class MicroMutableOpResolver : public MicroOpResolver { TfLiteStatus AddDynamicUpdateSlice() { return AddBuiltin(BuiltinOperator_DYNAMIC_UPDATE_SLICE, - Register_DYNAMIC_UPDATE_SLICE(), - ParseDynamicUpdateSlice); + Register_DYNAMIC_UPDATE_SLICE(), ParseDynamicUpdateSlice); } TfLiteStatus AddEmbeddingLookup( From 665cf674b4298bb139b4bc8ba49a9068d3599cb7 Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 28 Oct 2025 12:43:34 +0000 Subject: [PATCH 08/15] Code style correction --- tensorflow/lite/micro/kernels/BUILD | 4 ++-- .../micro/kernels/dynamic_update_slice.cc | 21 +++++++++---------- .../lite/micro/kernels/dynamic_update_slice.h | 2 -- .../kernels/dynamic_update_slice_test.cc | 1 - 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD index 739dfc06b63..1e7ffed0720 100644 --- a/tensorflow/lite/micro/kernels/BUILD +++ b/tensorflow/lite/micro/kernels/BUILD @@ -261,7 +261,7 @@ tflm_kernel_cc_library( "dequantize_common.cc", "detection_postprocess.cc", "div.cc", - "dynamic_update_slice.cc", + "dynamic_update_slice.cc", "elementwise.cc", "elu.cc", "embedding_lookup.cc", @@ -351,7 +351,7 @@ tflm_kernel_cc_library( "decode_state_prune.h", "depthwise_conv.h", "dequantize.h", - "dynamic_update_slice.h", + "dynamic_update_slice.h", "ethosu.h", "fully_connected.h", "hard_swish.h", diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc index 0d8224497d4..52b1a853116 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc @@ -46,12 +46,12 @@ TfLiteStatus CalculateClampedStartIndices( // Recursive helper for N-dimensional slice update. template TfLiteStatus UpdateSliceRecursive(int current_dim, int max_dims, - const int32_t* output_strides, + const int32_t* output_strides, const int32_t* update_strides, - const int32_t* update_dims_data, - const T* update_tensor_data, - const int32_t* clamped_start_indices, - T* output_tensor_data) { + const int32_t* update_dims_data, + const T* update_tensor_data, + const int32_t* clamped_start_indices, + T* output_tensor_data) { if (current_dim == max_dims) { return kTfLiteOk; } @@ -65,8 +65,8 @@ TfLiteStatus UpdateSliceRecursive(int current_dim, int max_dims, } else { for (int i = 0; i < update_dims_data[current_dim]; ++i) { UpdateSliceRecursive(current_dim + 1, max_dims, output_strides, - update_strides, update_dims_data, - update_tensor_data, clamped_start_indices, + update_strides, update_dims_data, + update_tensor_data, clamped_start_indices, output_tensor_data); output_tensor_data += output_strides[current_dim]; @@ -216,13 +216,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (operand_eval->type) { case kTfLiteFloat32: return EvalImpl(operand_eval, update_eval, indices_data_i64, - output_eval); + output_eval); case kTfLiteInt8: return EvalImpl(operand_eval, update_eval, indices_data_i64, - output_eval); + output_eval); case kTfLiteInt32: return EvalImpl(operand_eval, update_eval, indices_data_i64, - output_eval); + output_eval); default: MicroPrintf("DYNAMIC_UPDATE_SLICE: Operand type %s not supported.", TfLiteTypeGetName(operand_eval->type)); @@ -239,4 +239,3 @@ TFLMRegistration Register_DYNAMIC_UPDATE_SLICE() { } } // namespace tflite - diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice.h b/tensorflow/lite/micro/kernels/dynamic_update_slice.h index 3f68e97aa93..89546110b72 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice.h +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice.h @@ -33,6 +33,4 @@ TFLMRegistration Register_DYNAMIC_UPDATE_SLICE(); } // namespace tflite - #endif // TENSORFLOW_LITE_MICRO_KERNELS_DYNAMIC_UPDATE_SLICE_H_ - diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc index 631d86951b0..036cfe8031b 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc @@ -143,4 +143,3 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestBoundaryTest) { output_data); } TF_LITE_MICRO_TESTS_END - From af40c6ecbf28fa1492c4e98e23dca26b4d82e1ff Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Wed, 29 Oct 2025 09:12:34 +0000 Subject: [PATCH 09/15] Replaced hard coded MaxDimensions to RuntimeShape::kMaxSmallSize --- tensorflow/lite/micro/kernels/dynamic_update_slice.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc index 52b1a853116..4a955052514 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc @@ -27,7 +27,7 @@ limitations under the License. namespace tflite { -constexpr int kMaxDimensions = 6; +constexpr int kMaxDimensions = RuntimeShape::kMaxSmallSize; namespace { From bca3a54b77cecb0ea0a50c3f10bb24f509cd5125 Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 4 Nov 2025 14:02:39 +0000 Subject: [PATCH 10/15] 1. Added more test cases \n2.Removed unused code --- tensorflow/lite/micro/kernels/Makefile.inc | 1 + .../micro/kernels/dynamic_update_slice.cc | 80 ++++++------ .../kernels/dynamic_update_slice_test.cc | 119 ++++++++++++++---- 3 files changed, 128 insertions(+), 72 deletions(-) diff --git a/tensorflow/lite/micro/kernels/Makefile.inc b/tensorflow/lite/micro/kernels/Makefile.inc index b3bd47c8a17..78b851cb209 100644 --- a/tensorflow/lite/micro/kernels/Makefile.inc +++ b/tensorflow/lite/micro/kernels/Makefile.inc @@ -130,6 +130,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dequantize_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/div_test.cc \ +$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elementwise_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/elu_test.cc \ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/embedding_lookup_test.cc \ diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc index 4a955052514..42ccdd14842 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice.cc @@ -31,34 +31,30 @@ constexpr int kMaxDimensions = RuntimeShape::kMaxSmallSize; namespace { -TfLiteStatus CalculateClampedStartIndices( - int num_dims, const int64_t* raw_indices_data, - const int32_t* input_dims_data, const int32_t* update_dims_data, - int32_t* clamped_start_indices_output) { +void CalculateClampedStartIndices(int num_dims, const int64_t* raw_indices_data, + const int32_t* input_dims_data, + const int32_t* update_dims_data, + int32_t* clamped_start_indices_output) { for (int i = 0; i < num_dims; ++i) { clamped_start_indices_output[i] = static_cast( std::min(std::max(0, raw_indices_data[i]), input_dims_data[i] - update_dims_data[i])); } - return kTfLiteOk; + return; } // Recursive helper for N-dimensional slice update. template -TfLiteStatus UpdateSliceRecursive(int current_dim, int max_dims, - const int32_t* output_strides, - const int32_t* update_strides, - const int32_t* update_dims_data, - const T* update_tensor_data, - const int32_t* clamped_start_indices, - T* output_tensor_data) { - if (current_dim == max_dims) { - return kTfLiteOk; - } - +void UpdateSliceRecursive(int current_dim, int max_dims, + const int32_t* output_strides, + const int32_t* update_strides, + const int32_t* update_dims_data, + const T* update_tensor_data, + const int32_t* clamped_start_indices, + T* output_tensor_data) { + if (current_dim == max_dims) return; output_tensor_data += clamped_start_indices[current_dim] * output_strides[current_dim]; - if (current_dim == max_dims - 1) { std::memcpy(output_tensor_data, update_tensor_data, update_dims_data[max_dims - 1] * sizeof(T)); @@ -68,20 +64,17 @@ TfLiteStatus UpdateSliceRecursive(int current_dim, int max_dims, update_strides, update_dims_data, update_tensor_data, clamped_start_indices, output_tensor_data); - output_tensor_data += output_strides[current_dim]; update_tensor_data += update_strides[current_dim]; } } - return kTfLiteOk; } // Main dispatch function for Eval, templated on data type. template -TfLiteStatus EvalImpl(const TfLiteEvalTensor* operand_eval, - const TfLiteEvalTensor* update_eval, - const int64_t* indices_eval, - TfLiteEvalTensor* output_eval) { +void EvalImpl(const TfLiteEvalTensor* operand_eval, + const TfLiteEvalTensor* update_eval, const int64_t* indices_eval, + TfLiteEvalTensor* output_eval) { const RuntimeShape operand_shape = tflite::micro::GetTensorShape(operand_eval); const RuntimeShape update_shape = tflite::micro::GetTensorShape(update_eval); @@ -92,16 +85,10 @@ TfLiteStatus EvalImpl(const TfLiteEvalTensor* operand_eval, if (operand_shape.FlatSize() == update_shape.FlatSize()) { std::memcpy(output_tensor_data, update_tensor_data, ElementCount(*operand_eval->dims) * sizeof(T)); - return kTfLiteOk; - } - - if (num_dims > kMaxDimensions) { - MicroPrintf( - "DYNAMIC_UPDATE_SLICE: Operand rank %d exceeds max supported %d.", - num_dims, kMaxDimensions); - return kTfLiteError; + return; } + // If the operation is not done in-place, copy the input data to the output. if (operand_eval->data.data != output_eval->data.data) { std::memcpy(output_eval->data.data, operand_eval->data.data, ElementCount(*operand_eval->dims) * sizeof(T)); @@ -109,14 +96,13 @@ TfLiteStatus EvalImpl(const TfLiteEvalTensor* operand_eval, // If update tensor is empty, no actual update is needed after operand copy. if (ElementCount(*update_eval->dims) == 0) { - return kTfLiteOk; + return; } // Calculate clamped start indices (stack-allocated) int32_t clamped_start_indices[kMaxDimensions]; - TF_LITE_ENSURE_STATUS(CalculateClampedStartIndices( - num_dims, indices_eval, operand_shape.DimsData(), update_shape.DimsData(), - clamped_start_indices)); + CalculateClampedStartIndices(num_dims, indices_eval, operand_shape.DimsData(), + update_shape.DimsData(), clamped_start_indices); // Calculate strides (stack-allocated) int32_t output_stride[kMaxDimensions]; @@ -130,7 +116,7 @@ TfLiteStatus EvalImpl(const TfLiteEvalTensor* operand_eval, // Perform the N-dimensional update // The recursive function needs base pointers and initial offsets. - return UpdateSliceRecursive( + UpdateSliceRecursive( /*current_dim=*/0, num_dims, output_stride, update_stride, update_shape.DimsData(), update_tensor_data, clamped_start_indices, output_tensor_data); @@ -174,14 +160,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, SizeOfDimension(update, i) <= SizeOfDimension(operand, i)); } - output->type = operand->type; // Deallocate temporary tensors micro_context->DeallocateTempTfLiteTensor(operand); micro_context->DeallocateTempTfLiteTensor(update); micro_context->DeallocateTempTfLiteTensor(start_indices); - micro_context->DeallocateTempTfLiteTensor( - output); // Output tensor metadata also temp + micro_context->DeallocateTempTfLiteTensor(output); return kTfLiteOk; } @@ -215,14 +199,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // Dispatch based on tensor type switch (operand_eval->type) { case kTfLiteFloat32: - return EvalImpl(operand_eval, update_eval, indices_data_i64, - output_eval); + EvalImpl(operand_eval, update_eval, indices_data_i64, output_eval); + break; case kTfLiteInt8: - return EvalImpl(operand_eval, update_eval, indices_data_i64, - output_eval); + EvalImpl(operand_eval, update_eval, indices_data_i64, + output_eval); + break; + case kTfLiteInt16: + EvalImpl(operand_eval, update_eval, indices_data_i64, + output_eval); + break; case kTfLiteInt32: - return EvalImpl(operand_eval, update_eval, indices_data_i64, - output_eval); + EvalImpl(operand_eval, update_eval, indices_data_i64, + output_eval); + break; default: MicroPrintf("DYNAMIC_UPDATE_SLICE: Operand type %s not supported.", TfLiteTypeGetName(operand_eval->type)); diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc index 036cfe8031b..86283579e91 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc @@ -26,7 +26,6 @@ namespace tflite { namespace testing { namespace { -// constexpr float kTestTolerance = 7.41e-03; constexpr int kNumInputs = 3; constexpr int kNumOutputs = 1; constexpr int kInputTensorIndex_0 = 0; @@ -34,16 +33,6 @@ constexpr int kInputTensorIndex_1 = 1; constexpr int kInputTensorIndex_2 = 2; constexpr int kOutputTensorIndex = 3; -// min/max are used to compute scale, zero-point is 0 -template -struct TestDynamicUpdateSliceParams { - // quantization parameters - float data_min; // input data minimum value - float data_max; // input data maximum value - int8_t input1_data[kInputSize]; // quantized input storage - int8_t input2_data[kInputSize]; // quantized input storage -}; - void ExecuteDynamicUpdateSliceTest(TfLiteTensor* tensors, int tensors_count) { int kInputArrayData[] = {kNumInputs, kInputTensorIndex_0, kInputTensorIndex_1, kInputTensorIndex_2}; @@ -59,12 +48,11 @@ void ExecuteDynamicUpdateSliceTest(TfLiteTensor* tensors, int tensors_count) { TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke()); } -void TestDynamicUpdateSliceFloat(int* input_dims_data[kNumInputs], - const float* input_data_0, - const float* input_data_1, - const int32_t* input_data_2, - const float* golden_data, int* expected_dims, - float* output_data) { +template +void TestDynamicUpdateSlice(int* input_dims_data[kNumInputs], + const T* input_data_0, const T* input_data_1, + const U* input_data_2, const T* golden_data, + int* expected_dims, T* output_data) { TfLiteIntArray* input_dims_0 = IntArrayFromInts(input_dims_data[0]); TfLiteIntArray* input_dims_1 = IntArrayFromInts(input_dims_data[1]); TfLiteIntArray* input_dims_2 = IntArrayFromInts(input_dims_data[2]); @@ -82,8 +70,7 @@ void TestDynamicUpdateSliceFloat(int* input_dims_data[kNumInputs], // check output data against expected for (int i = 0; i < output_count; i++) { - printf("output_data[%d] = %f\n", i, output_data[i]); - TF_LITE_MICRO_EXPECT_NEAR(golden_data[i], output_data[i], 0.0); + TF_LITE_MICRO_EXPECT_EQ(golden_data[i], output_data[i]); } // check output dimensions (relocated) against original dimensions @@ -95,8 +82,6 @@ void TestDynamicUpdateSliceFloat(int* input_dims_data[kNumInputs], } } -// TODO(rameshkunasi): Add quantized test for dynamic update slice. - } // namespace } // namespace testing } // namespace tflite @@ -118,9 +103,89 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleFloat) { constexpr int kOutputCount = std::extent::value; float output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSliceFloat(kInputDims, kInput_0, kInput_1, - kInput_2, kExpect, kOutputDims, - output_data); + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); +} + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8) { + int32_t kInputDims_0[] = {2, 3, 3}; + int32_t kInputDims_1[] = {2, 2, 1}; + int32_t kInputDims_2[] = {1, 2}; + int32_t* kInputDims[tflite::testing::kNumInputs] = { + kInputDims_0, kInputDims_1, kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr int8_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr int8_t kInput_1[] = {-1, -2}; + constexpr int32_t kInput_2[] = {1, 1}; + constexpr int8_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kOutputCount = std::extent::value; + int8_t output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); +} + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt16) { + int32_t kInputDims_0[] = {2, 3, 3}; + int32_t kInputDims_1[] = {2, 2, 1}; + int32_t kInputDims_2[] = {1, 2}; + int32_t* kInputDims[tflite::testing::kNumInputs] = { + kInputDims_0, kInputDims_1, kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr int16_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr int16_t kInput_1[] = {-1, -2}; + constexpr int32_t kInput_2[] = {1, 1}; + constexpr int16_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kOutputCount = std::extent::value; + int16_t output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); +} + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt32) { + int32_t kInputDims_0[] = {2, 3, 3}; + int32_t kInputDims_1[] = {2, 2, 1}; + int32_t kInputDims_2[] = {1, 2}; + int32_t* kInputDims[tflite::testing::kNumInputs] = { + kInputDims_0, kInputDims_1, kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr int32_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr int32_t kInput_1[] = {-1, -2}; + constexpr int32_t kInput_2[] = {1, 1}; + constexpr int32_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kOutputCount = std::extent::value; + int32_t output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); +} + +TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8IndicesI64) { + int32_t kInputDims_0[] = {2, 3, 3}; + int32_t kInputDims_1[] = {2, 2, 1}; + int32_t kInputDims_2[] = {1, 2}; + int32_t* kInputDims[tflite::testing::kNumInputs] = { + kInputDims_0, kInputDims_1, kInputDims_2}; + int kOutputDims[] = {2, 3, 3}; + + constexpr int8_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr int8_t kInput_1[] = {-1, -2}; + constexpr int64_t kInput_2[] = {1, 1}; + constexpr int8_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kOutputCount = std::extent::value; + int8_t output_data[kOutputCount]; + + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); } TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestBoundaryTest) { @@ -138,8 +203,8 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestBoundaryTest) { constexpr int kOutputCount = std::extent::value; float output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSliceFloat(kInputDims, kInput_0, kInput_1, - kInput_2, kExpect, kOutputDims, - output_data); + tflite::testing::TestDynamicUpdateSlice(kInputDims, kInput_0, kInput_1, + kInput_2, kExpect, kOutputDims, + output_data); } TF_LITE_MICRO_TESTS_END From 93b534b78436dbd26047a73c29764f60bf8c434d Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 4 Nov 2025 14:23:01 +0000 Subject: [PATCH 11/15] Updates for test failure on ARM --- .../kernels/dynamic_update_slice_test.cc | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc index 86283579e91..ebdf8506085 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc @@ -89,16 +89,16 @@ void TestDynamicUpdateSlice(int* input_dims_data[kNumInputs], TF_LITE_MICRO_TESTS_BEGIN TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleFloat) { - int32_t kInputDims_0[] = {2, 3, 3}; - int32_t kInputDims_1[] = {2, 2, 1}; - int32_t kInputDims_2[] = {1, 2}; - int32_t* kInputDims[tflite::testing::kNumInputs] = { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 1}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = { kInputDims_0, kInputDims_1, kInputDims_2}; int kOutputDims[] = {2, 3, 3}; constexpr float kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; constexpr float kInput_1[] = {-1, -2}; - constexpr int32_t kInput_2[] = {1, 1}; + constexpr int kInput_2[] = {1, 1}; constexpr float kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; constexpr int kOutputCount = std::extent::value; float output_data[kOutputCount]; @@ -109,30 +109,30 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleFloat) { } TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8) { - int32_t kInputDims_0[] = {2, 3, 3}; - int32_t kInputDims_1[] = {2, 2, 1}; - int32_t kInputDims_2[] = {1, 2}; - int32_t* kInputDims[tflite::testing::kNumInputs] = { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 1}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = { kInputDims_0, kInputDims_1, kInputDims_2}; int kOutputDims[] = {2, 3, 3}; constexpr int8_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; constexpr int8_t kInput_1[] = {-1, -2}; - constexpr int32_t kInput_2[] = {1, 1}; + constexpr int kInput_2[] = {1, 1}; constexpr int8_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; constexpr int kOutputCount = std::extent::value; int8_t output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice( + tflite::testing::TestDynamicUpdateSlice( kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, output_data); } TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt16) { - int32_t kInputDims_0[] = {2, 3, 3}; - int32_t kInputDims_1[] = {2, 2, 1}; - int32_t kInputDims_2[] = {1, 2}; - int32_t* kInputDims[tflite::testing::kNumInputs] = { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 1}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = { kInputDims_0, kInputDims_1, kInputDims_2}; int kOutputDims[] = {2, 3, 3}; @@ -143,36 +143,36 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt16) { constexpr int kOutputCount = std::extent::value; int16_t output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice( + tflite::testing::TestDynamicUpdateSlice( kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, output_data); } TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt32) { - int32_t kInputDims_0[] = {2, 3, 3}; - int32_t kInputDims_1[] = {2, 2, 1}; - int32_t kInputDims_2[] = {1, 2}; - int32_t* kInputDims[tflite::testing::kNumInputs] = { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 1}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = { kInputDims_0, kInputDims_1, kInputDims_2}; int kOutputDims[] = {2, 3, 3}; - constexpr int32_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - constexpr int32_t kInput_1[] = {-1, -2}; - constexpr int32_t kInput_2[] = {1, 1}; - constexpr int32_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr int kInput_1[] = {-1, -2}; + constexpr int kInput_2[] = {1, 1}; + constexpr int kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; constexpr int kOutputCount = std::extent::value; - int32_t output_data[kOutputCount]; + int output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice( + tflite::testing::TestDynamicUpdateSlice( kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, output_data); } TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8IndicesI64) { - int32_t kInputDims_0[] = {2, 3, 3}; - int32_t kInputDims_1[] = {2, 2, 1}; - int32_t kInputDims_2[] = {1, 2}; - int32_t* kInputDims[tflite::testing::kNumInputs] = { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 1}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = { kInputDims_0, kInputDims_1, kInputDims_2}; int kOutputDims[] = {2, 3, 3}; @@ -189,16 +189,16 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8IndicesI64) { } TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestBoundaryTest) { - int32_t kInputDims_0[] = {2, 3, 3}; - int32_t kInputDims_1[] = {2, 2, 2}; - int32_t kInputDims_2[] = {1, 2}; - int32_t* kInputDims[tflite::testing::kNumInputs] = { + int kInputDims_0[] = {2, 3, 3}; + int kInputDims_1[] = {2, 2, 2}; + int kInputDims_2[] = {1, 2}; + int* kInputDims[tflite::testing::kNumInputs] = { kInputDims_0, kInputDims_1, kInputDims_2}; int kOutputDims[] = {2, 3, 3}; constexpr float kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; constexpr float kInput_1[] = {-1, -2, -3, -4}; - constexpr int32_t kInput_2[] = {2, 2}; + constexpr int kInput_2[] = {2, 2}; constexpr float kExpect[] = {1, 2, 3, 4, -1, -2, 7, -3, -4}; constexpr int kOutputCount = std::extent::value; float output_data[kOutputCount]; From aa4aa15986e091b07d2fafe6563ca70ea59161ce Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 4 Nov 2025 14:33:34 +0000 Subject: [PATCH 12/15] Code style updates --- .../kernels/dynamic_update_slice_test.cc | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc index ebdf8506085..f0dfea5ca36 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc @@ -92,8 +92,8 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleFloat) { int kInputDims_0[] = {2, 3, 3}; int kInputDims_1[] = {2, 2, 1}; int kInputDims_2[] = {1, 2}; - int* kInputDims[tflite::testing::kNumInputs] = { - kInputDims_0, kInputDims_1, kInputDims_2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; int kOutputDims[] = {2, 3, 3}; constexpr float kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; @@ -112,8 +112,8 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8) { int kInputDims_0[] = {2, 3, 3}; int kInputDims_1[] = {2, 2, 1}; int kInputDims_2[] = {1, 2}; - int* kInputDims[tflite::testing::kNumInputs] = { - kInputDims_0, kInputDims_1, kInputDims_2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; int kOutputDims[] = {2, 3, 3}; constexpr int8_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; @@ -132,8 +132,8 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt16) { int kInputDims_0[] = {2, 3, 3}; int kInputDims_1[] = {2, 2, 1}; int kInputDims_2[] = {1, 2}; - int* kInputDims[tflite::testing::kNumInputs] = { - kInputDims_0, kInputDims_1, kInputDims_2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; int kOutputDims[] = {2, 3, 3}; constexpr int16_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; @@ -152,8 +152,8 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt32) { int kInputDims_0[] = {2, 3, 3}; int kInputDims_1[] = {2, 2, 1}; int kInputDims_2[] = {1, 2}; - int* kInputDims[tflite::testing::kNumInputs] = { - kInputDims_0, kInputDims_1, kInputDims_2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; int kOutputDims[] = {2, 3, 3}; constexpr int kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; @@ -172,8 +172,8 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8IndicesI64) { int kInputDims_0[] = {2, 3, 3}; int kInputDims_1[] = {2, 2, 1}; int kInputDims_2[] = {1, 2}; - int* kInputDims[tflite::testing::kNumInputs] = { - kInputDims_0, kInputDims_1, kInputDims_2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; int kOutputDims[] = {2, 3, 3}; constexpr int8_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; @@ -192,8 +192,8 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestBoundaryTest) { int kInputDims_0[] = {2, 3, 3}; int kInputDims_1[] = {2, 2, 2}; int kInputDims_2[] = {1, 2}; - int* kInputDims[tflite::testing::kNumInputs] = { - kInputDims_0, kInputDims_1, kInputDims_2}; + int* kInputDims[tflite::testing::kNumInputs] = {kInputDims_0, kInputDims_1, + kInputDims_2}; int kOutputDims[] = {2, 3, 3}; constexpr float kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; From 4c3b848d81a23a16b4d380cf1ab5cefc79f4b0e7 Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 4 Nov 2025 15:05:37 +0000 Subject: [PATCH 13/15] Updates on test case failure for ARM --- tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc index f0dfea5ca36..ed3a9758849 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc @@ -103,7 +103,7 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleFloat) { constexpr int kOutputCount = std::extent::value; float output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice( + tflite::testing::TestDynamicUpdateSlice( kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, output_data); } @@ -138,7 +138,7 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt16) { constexpr int16_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; constexpr int16_t kInput_1[] = {-1, -2}; - constexpr int32_t kInput_2[] = {1, 1}; + constexpr int kInput_2[] = {1, 1}; constexpr int16_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; constexpr int kOutputCount = std::extent::value; int16_t output_data[kOutputCount]; @@ -163,7 +163,7 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt32) { constexpr int kOutputCount = std::extent::value; int output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice( + tflite::testing::TestDynamicUpdateSlice( kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, output_data); } From 831dc33fadc6ed191941d3f5bed1672a9a578b40 Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 4 Nov 2025 16:30:07 +0000 Subject: [PATCH 14/15] Updates on test case failure for ARM --- .../kernels/dynamic_update_slice_test.cc | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc index ed3a9758849..b5f5df74a73 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc @@ -98,12 +98,12 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleFloat) { constexpr float kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; constexpr float kInput_1[] = {-1, -2}; - constexpr int kInput_2[] = {1, 1}; + constexpr int32_t kInput_2[] = {1, 1}; constexpr float kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; constexpr int kOutputCount = std::extent::value; float output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice( + tflite::testing::TestDynamicUpdateSlice( kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, output_data); } @@ -118,12 +118,12 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8) { constexpr int8_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; constexpr int8_t kInput_1[] = {-1, -2}; - constexpr int kInput_2[] = {1, 1}; + constexpr int32_t kInput_2[] = {1, 1}; constexpr int8_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; constexpr int kOutputCount = std::extent::value; int8_t output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice( + tflite::testing::TestDynamicUpdateSlice( kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, output_data); } @@ -138,12 +138,12 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt16) { constexpr int16_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; constexpr int16_t kInput_1[] = {-1, -2}; - constexpr int kInput_2[] = {1, 1}; + constexpr int32_t kInput_2[] = {1, 1}; constexpr int16_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; constexpr int kOutputCount = std::extent::value; int16_t output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice( + tflite::testing::TestDynamicUpdateSlice( kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, output_data); } @@ -156,16 +156,18 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt32) { kInputDims_2}; int kOutputDims[] = {2, 3, 3}; - constexpr int kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - constexpr int kInput_1[] = {-1, -2}; - constexpr int kInput_2[] = {1, 1}; - constexpr int kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; + constexpr int32_t kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + constexpr int32_t kInput_1[] = {-1, -2}; + constexpr int32_t kInput_2[] = {1, 1}; + constexpr int32_t kExpect[] = {1, 2, 3, 4, -1, 6, 7, -2, 9}; constexpr int kOutputCount = std::extent::value; - int output_data[kOutputCount]; + int32_t output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice( - kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, - output_data); + tflite::testing::TestDynamicUpdateSlice(kInputDims, + kInput_0, kInput_1, + kInput_2, kExpect, + kOutputDims, + output_data); } TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8IndicesI64) { @@ -198,13 +200,14 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestBoundaryTest) { constexpr float kInput_0[] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; constexpr float kInput_1[] = {-1, -2, -3, -4}; - constexpr int kInput_2[] = {2, 2}; + constexpr int32_t kInput_2[] = {2, 2}; constexpr float kExpect[] = {1, 2, 3, 4, -1, -2, 7, -3, -4}; constexpr int kOutputCount = std::extent::value; float output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice(kInputDims, kInput_0, kInput_1, - kInput_2, kExpect, kOutputDims, - output_data); + tflite::testing::TestDynamicUpdateSlice(kInputDims, kInput_0, + kInput_1, kInput_2, + kExpect, kOutputDims, + output_data); } TF_LITE_MICRO_TESTS_END From 3b752b53d489c0b4c1e3cb5e415cee0e5de49df5 Mon Sep 17 00:00:00 2001 From: Ramesh Kunasi Date: Tue, 4 Nov 2025 17:03:34 +0000 Subject: [PATCH 15/15] Code style updates --- .../micro/kernels/dynamic_update_slice_test.cc | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc index b5f5df74a73..0bfd6c89740 100644 --- a/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc +++ b/tensorflow/lite/micro/kernels/dynamic_update_slice_test.cc @@ -163,11 +163,9 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt32) { constexpr int kOutputCount = std::extent::value; int32_t output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice(kInputDims, - kInput_0, kInput_1, - kInput_2, kExpect, - kOutputDims, - output_data); + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); } TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestSimpleInt8IndicesI64) { @@ -205,9 +203,8 @@ TF_LITE_MICRO_TEST(DynamicUpdateSliceOpTestBoundaryTest) { constexpr int kOutputCount = std::extent::value; float output_data[kOutputCount]; - tflite::testing::TestDynamicUpdateSlice(kInputDims, kInput_0, - kInput_1, kInput_2, - kExpect, kOutputDims, - output_data); + tflite::testing::TestDynamicUpdateSlice( + kInputDims, kInput_0, kInput_1, kInput_2, kExpect, kOutputDims, + output_data); } TF_LITE_MICRO_TESTS_END