diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc index c6459cfcc1e..06e645e052a 100644 --- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc +++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc @@ -473,7 +473,7 @@ void LstmStepManager::UpdateBatch() { // Multi-batch for time_major input RuntimeShape LstmStepManager::InputShape() const { int batch_size = 1; - if (size_info_.time_major) { + if (size_info_.time_major || ((size_info_.batch_size > 1 && size_info_.time_steps == 1))) { batch_size = size_info_.batch_size; } const int dims[2] = {batch_size, size_info_.input_dimension}; @@ -485,7 +485,7 @@ RuntimeShape LstmStepManager::InputShape() const { // Multi-batch for time_major input RuntimeShape LstmStepManager::StateShape() const { int batch_size = 1; - if (size_info_.time_major) { + if (size_info_.time_major || (size_info_.batch_size > 1 && size_info_.time_steps == 1)) { batch_size = size_info_.batch_size; } const int dims[2] = {batch_size, size_info_.state_dimension}; diff --git a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h index 0ba5e22a083..67b928a4381 100644 --- a/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h +++ b/tensorflow/lite/micro/kernels/xtensa/lstm_eval.h @@ -661,10 +661,14 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data, kernel_content.GetInternalTensor(tflite::kLstmInputTensor); TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor(); - int time_major = step_info.time_major(); - int num_batches = time_major == 0 ? 1 : step_info.batch_size(); - int input_dimension = step_info.input_dimension(); - int state_dimension = step_info.state_dimension(); + const auto& size_info = op_data.size_info; + const int time_major = step_info.time_major(); + const int batch_size = size_info.batch_size; + const int time_steps = size_info.time_steps; + const int num_batches = time_major == 0 ? (time_steps == 1 ? batch_size : 1) + : step_info.batch_size(); + const int input_dimension = step_info.input_dimension(); + const int state_dimension = step_info.state_dimension(); // Check offset validity to avoid memory overflow TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension, @@ -803,8 +807,11 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data, // prepare for the next time step step_info.UpdateTime(); } + } else if(size_info.batch_size > 1 && size_info.time_steps == 1) { + // Ramesh + lstm_internal::LstmStep( + step_info, op_data, kernel_content, buffers); } else { - // batch first, unable to size the input data. single batch inference for (int b = 0; b < size_info.batch_size; b++) { for (int t = 0; t < size_info.time_steps; t++) { lstm_internal::LstmStep(