From c12210579cb385e991287a33638d365e6d0fc587 Mon Sep 17 00:00:00 2001 From: pramods-cad Date: Mon, 27 Oct 2025 08:22:34 -0700 Subject: [PATCH 1/2] Optimization in LSTM for batch > 1 cases on HiFi. --- .../lite/micro/kernels/xtensa/lstm_eval.cc | 4 +-- .../lite/micro/kernels/xtensa/lstm_eval.h | 29 +++++++++++++------ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc index c6459cfcc1e..06e645e052a 100644 --- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc +++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc @@ -473,7 +473,7 @@ void LstmStepManager::UpdateBatch() { // Multi-batch for time_major input RuntimeShape LstmStepManager::InputShape() const { int batch_size = 1; - if (size_info_.time_major) { + if (size_info_.time_major || ((size_info_.batch_size > 1 && size_info_.time_steps == 1))) { batch_size = size_info_.batch_size; } const int dims[2] = {batch_size, size_info_.input_dimension}; @@ -485,7 +485,7 @@ RuntimeShape LstmStepManager::InputShape() const { // Multi-batch for time_major input RuntimeShape LstmStepManager::StateShape() const { int batch_size = 1; - if (size_info_.time_major) { + if (size_info_.time_major || (size_info_.batch_size > 1 && size_info_.time_steps == 1)) { batch_size = size_info_.batch_size; } const int dims[2] = {batch_size, size_info_.state_dimension}; diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h index 0ba5e22a083..5b1934a95af 100644 --- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h +++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h @@ -666,6 +666,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, int input_dimension = step_info.input_dimension(); int state_dimension = step_info.state_dimension(); + const auto& size_info = op_data.size_info; + if(size_info.batch_size > 1 && size_info.time_steps == 1) { + num_batches = size_info.batch_size; + } + // Check offset validity to avoid memory overflow TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension, tflite::micro::GetTensorShape(input).FlatSize()); @@ -805,16 +810,22 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data, } } else { // batch first, unable to size the input data. single batch inference - for (int b = 0; b < size_info.batch_size; b++) { - for (int t = 0; t < size_info.time_steps; t++) { - lstm_internal::LstmStep( - step_info, op_data, kernel_content, buffers); - // prepare for the next time step - step_info.UpdateTime(); + if(size_info.batch_size > 1 && size_info.time_steps == 1) { + // Ramesh + lstm_internal::LstmStep( + step_info, op_data, kernel_content, buffers); + } else { + for (int b = 0; b < size_info.batch_size; b++) { + for (int t = 0; t < size_info.time_steps; t++) { + lstm_internal::LstmStep( + step_info, op_data, kernel_content, buffers); + // prepare for the next time step + step_info.UpdateTime(); + } + // prepare for the next batch + step_info.UpdateBatch(); + step_info.ResetTime(); } - // prepare for the next batch - step_info.UpdateBatch(); - step_info.ResetTime(); } } return kTfLiteOk; From c98ec2884f2c6fbc0a7624911df50f69d064eb0e Mon Sep 17 00:00:00 2001 From: pramods-cad Date: Fri, 31 Oct 2025 08:19:58 -0700 Subject: [PATCH 2/2] Addressed review comments. --- .../lite/micro/kernels/xtensa/lstm_eval.h | 44 +++++++++---------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h index 5b1934a95af..67b928a4381 100644 --- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h +++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h @@ -661,15 +661,14 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, kernel_content.GetInternalTensor(tflite::kLstmInputTensor); TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor(); - int time_major = step_info.time_major(); - int num_batches = time_major == 0 ? 1 : step_info.batch_size(); - int input_dimension = step_info.input_dimension(); - int state_dimension = step_info.state_dimension(); - const auto& size_info = op_data.size_info; - if(size_info.batch_size > 1 && size_info.time_steps == 1) { - num_batches = size_info.batch_size; - } + const int time_major = step_info.time_major(); + const int batch_size = size_info.batch_size; + const int time_steps = size_info.time_steps; + const int num_batches = time_major == 0 ? (time_steps == 1 ? batch_size : 1) + : step_info.batch_size(); + const int input_dimension = step_info.input_dimension(); + const int state_dimension = step_info.state_dimension(); // Check offset validity to avoid memory overflow TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension, @@ -808,24 +807,21 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data, // prepare for the next time step step_info.UpdateTime(); } + } else if(size_info.batch_size > 1 && size_info.time_steps == 1) { + // Ramesh + lstm_internal::LstmStep( + step_info, op_data, kernel_content, buffers); } else { - // batch first, unable to size the input data. single batch inference - if(size_info.batch_size > 1 && size_info.time_steps == 1) { - // Ramesh - lstm_internal::LstmStep( - step_info, op_data, kernel_content, buffers); - } else { - for (int b = 0; b < size_info.batch_size; b++) { - for (int t = 0; t < size_info.time_steps; t++) { - lstm_internal::LstmStep( - step_info, op_data, kernel_content, buffers); - // prepare for the next time step - step_info.UpdateTime(); - } - // prepare for the next batch - step_info.UpdateBatch(); - step_info.ResetTime(); + for (int b = 0; b < size_info.batch_size; b++) { + for (int t = 0; t < size_info.time_steps; t++) { + lstm_internal::LstmStep( + step_info, op_data, kernel_content, buffers); + // prepare for the next time step + step_info.UpdateTime(); } + // prepare for the next batch + step_info.UpdateBatch(); + step_info.ResetTime(); } } return kTfLiteOk;