Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ void LstmStepManager::UpdateBatch() {
// Multi-batch for time_major input
RuntimeShape LstmStepManager::InputShape() const {
int batch_size = 1;
if (size_info_.time_major) {
if (size_info_.time_major || ((size_info_.batch_size > 1 && size_info_.time_steps == 1))) {
batch_size = size_info_.batch_size;
}
const int dims[2] = {batch_size, size_info_.input_dimension};
Expand All @@ -485,7 +485,7 @@ RuntimeShape LstmStepManager::InputShape() const {
// Multi-batch for time_major input
RuntimeShape LstmStepManager::StateShape() const {
int batch_size = 1;
if (size_info_.time_major) {
if (size_info_.time_major || (size_info_.batch_size > 1 && size_info_.time_steps == 1)) {
batch_size = size_info_.batch_size;
}
const int dims[2] = {batch_size, size_info_.state_dimension};
Expand Down
17 changes: 12 additions & 5 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -661,10 +661,14 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
kernel_content.GetInternalTensor(tflite::kLstmInputTensor);
TfLiteEvalTensor* recurrent = kernel_content.HiddenStateTensor();

int time_major = step_info.time_major();
int num_batches = time_major == 0 ? 1 : step_info.batch_size();
int input_dimension = step_info.input_dimension();
int state_dimension = step_info.state_dimension();
const auto& size_info = op_data.size_info;
const int time_major = step_info.time_major();
const int batch_size = size_info.batch_size;
const int time_steps = size_info.time_steps;
const int num_batches = time_major == 0 ? (time_steps == 1 ? batch_size : 1)
: step_info.batch_size();
const int input_dimension = step_info.input_dimension();
const int state_dimension = step_info.state_dimension();

// Check offset validity to avoid memory overflow
TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension,
Expand Down Expand Up @@ -803,8 +807,11 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
// prepare for the next time step
step_info.UpdateTime();
}
} else if(size_info.batch_size > 1 && size_info.time_steps == 1) {
// Ramesh
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
step_info, op_data, kernel_content, buffers);
} else {
// batch first, unable to size the input data. single batch inference
for (int b = 0; b < size_info.batch_size; b++) {
for (int t = 0; t < size_info.time_steps; t++) {
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
Expand Down
Loading