Commit d3d4501

lantiga and alonre24 authored
Fix TfLite issue (#701)
Co-authored-by: alonre24 <[email protected]>
1 parent d90934b commit d3d4501

File tree

5 files changed: +132 -17 lines changed


src/backends/libtflite_c/tflite_c.cpp

Lines changed: 60 additions & 3 deletions
@@ -4,6 +4,7 @@
 #include "redismodule.h"
 #include "tensorflow/lite/model.h"
 #include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/util.h"
 #include "tensorflow/lite/kernels/register.h"
 #include "tensorflow/lite/tools/evaluation/utils.h"

@@ -52,6 +53,7 @@ static DLDataType getDLDataType(const TfLiteTensor *tensor) {
     return dtype;
 }
 
+
 static DLDevice getDLDevice(const TfLiteTensor *tensor, const int64_t &device_id) {
     DLDevice device;
     device.device_id = device_id;
@@ -75,29 +77,52 @@ size_t dltensorBytes(DLManagedTensor *t) {
 void copyToTfLiteTensor(std::shared_ptr<tflite::Interpreter> interpreter, int tflite_input,
                         DLManagedTensor *input) {
     TfLiteTensor *tensor = interpreter->tensor(tflite_input);
-
     size_t nbytes = dltensorBytes(input);
+    DLDataType dltensor_type = input->dl_tensor.dtype;
+    const char *type_mismatch_msg = "Input tensor type doesn't match the type expected"
+                                    " by the model definition";
 
     switch (tensor->type) {
     case kTfLiteUInt8:
-        memcpy(interpreter->typed_tensor<uint8_t>(tflite_input), input->dl_tensor.data, nbytes);
-        break;
+        if (dltensor_type.code != kDLUInt || dltensor_type.bits != 8) {
+            throw std::logic_error(type_mismatch_msg);
+        }
+        memcpy(interpreter->typed_tensor<uint8_t>(tflite_input), input->dl_tensor.data, nbytes);
+        break;
     case kTfLiteInt64:
+        if (dltensor_type.code != kDLInt || dltensor_type.bits != 64) {
+            throw std::logic_error(type_mismatch_msg);
+        }
         memcpy(interpreter->typed_tensor<int64_t>(tflite_input), input->dl_tensor.data, nbytes);
         break;
     case kTfLiteInt32:
+        if (dltensor_type.code != kDLInt || dltensor_type.bits != 32) {
+            throw std::logic_error(type_mismatch_msg);
+        }
         memcpy(interpreter->typed_tensor<int32_t>(tflite_input), input->dl_tensor.data, nbytes);
         break;
     case kTfLiteInt16:
+        if (dltensor_type.code != kDLInt || dltensor_type.bits != 16) {
+            throw std::logic_error(type_mismatch_msg);
+        }
         memcpy(interpreter->typed_tensor<int16_t>(tflite_input), input->dl_tensor.data, nbytes);
         break;
     case kTfLiteInt8:
+        if (dltensor_type.code != kDLInt || dltensor_type.bits != 8) {
+            throw std::logic_error(type_mismatch_msg);
+        }
         memcpy(interpreter->typed_tensor<int8_t>(tflite_input), input->dl_tensor.data, nbytes);
         break;
     case kTfLiteFloat32:
+        if (dltensor_type.code != kDLFloat || dltensor_type.bits != 32) {
+            throw std::logic_error(type_mismatch_msg);
+        }
         memcpy(interpreter->typed_tensor<float>(tflite_input), input->dl_tensor.data, nbytes);
         break;
     case kTfLiteBool:
+        if (dltensor_type.code != kDLBool || dltensor_type.bits != 8) {
+            throw std::logic_error(type_mismatch_msg);
+        }
         memcpy(interpreter->typed_tensor<bool>(tflite_input), input->dl_tensor.data, nbytes);
     case kTfLiteFloat16:
         throw std::logic_error("Float16 not currently supported as input tensor data type");
@@ -318,6 +343,38 @@ extern "C" void tfliteRunModel(void *ctx, long n_inputs, DLManagedTensor **input
         return;
     }
 
+    // NOTE: TFLITE requires all tensors in the graph to be explicitly
+    // preallocated before input tensors are memcopied. These are cached
+    // in the session, so we need to check if for instance the batch size
+    // has changed or the shape has changed in general compared to the
+    // previous run and in that case we resize input tensors and call
+    // the AllocateTensor function manually.
+    bool need_reallocation = false;
+    std::vector<int> dims;
+    for (size_t i = 0; i < tflite_inputs.size(); i++) {
+        const TfLiteTensor* tflite_tensor = interpreter->tensor(tflite_inputs[i]);
+        int64_t ndim = inputs[i]->dl_tensor.ndim;
+        int64_t *shape = inputs[i]->dl_tensor.shape;
+        dims.resize(ndim);
+        for (size_t j=0; j < ndim; j++) {
+            dims[j] = shape[j];
+        }
+        if (!tflite::EqualArrayAndTfLiteIntArray(tflite_tensor->dims, dims.size(), dims.data())) {
+            if (interpreter->ResizeInputTensor(i, dims) != kTfLiteOk) {
+                _setError("Failed to resize input tensors", error);
+                return;
+            }
+            need_reallocation = true;
+        }
+    }
+
+    if (need_reallocation) {
+        if (interpreter->AllocateTensors() != kTfLiteOk) {
+            _setError("Failed to allocate tensors", error);
+            return;
+        }
+    }
+
     try {
         for (size_t i = 0; i < tflite_inputs.size(); i++) {
             copyToTfLiteTensor(interpreter, tflite_inputs[i], inputs[i]);
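For context, the resize-then-reallocate flow this hunk adds can be exercised against the public TFLite C++ interpreter API on its own. The following is a minimal sketch under stated assumptions: SetFloatInput, its parameters, and the float-only focus are illustrative, and the interpreter is assumed to have had AllocateTensors() called once at model-load time:

#include <cstring>
#include <vector>

#include "tensorflow/lite/interpreter.h"

// Copies `nbytes` of float data into input 0, resizing first when the shape
// cached in the interpreter (e.g. a previous batch size) no longer matches.
static bool SetFloatInput(tflite::Interpreter *interpreter,
                          const std::vector<int> &dims, const float *data,
                          size_t nbytes) {
    int idx = interpreter->inputs()[0];
    const TfLiteTensor *tensor = interpreter->tensor(idx);

    bool same = tensor->dims->size == static_cast<int>(dims.size());
    for (int i = 0; same && i < tensor->dims->size; i++) {
        same = (tensor->dims->data[i] == dims[i]);
    }
    if (!same) {
        // TFLite caches allocations per session, so a shape change needs an
        // explicit resize followed by reallocation of the whole graph.
        if (interpreter->ResizeInputTensor(idx, dims) != kTfLiteOk ||
            interpreter->AllocateTensors() != kTfLiteOk) {
            return false;
        }
    }
    std::memcpy(interpreter->typed_tensor<float>(idx), data, nbytes);
    return true;
}

This dynamic resizing is also what makes it safe to drop the TFLITE auto-batching guard in src/redisai.c below.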

src/redisai.c

Lines changed: 0 additions & 4 deletions
@@ -201,10 +201,6 @@ int RedisAI_ModelStore_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **arg
 
     unsigned long long batchsize = 0;
     if (AC_AdvanceIfMatch(&ac, "BATCHSIZE")) {
-        if (backend == RAI_BACKEND_TFLITE) {
-            return RedisModule_ReplyWithError(
-                ctx, "ERR Auto-batching not supported by the TFLITE backend");
-        }
         if (AC_GetUnsignedLongLong(&ac, &batchsize, 0) != AC_OK) {
             return RedisModule_ReplyWithError(ctx, "ERR Invalid argument for BATCHSIZE");
         }
tests/flow/test_data/lite-model_imagenet_mobilenet_v3_small_100_224_classification_5_default_1.tflite

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9df60af7ab24d287a54668e845ea7da1c854086b828a4a4cf46c55c403095053
+size 10209756

tests/flow/tests_tensorflow.py

Lines changed: 0 additions & 1 deletion
@@ -646,7 +646,6 @@ def test_tensorflow_modelexecute_script_execute_resnet(env):
     inputvar = 'images'
     outputvar = 'output'
 
-
     model_pb, script, labels, img = load_resnet_test_data()
 
     ret = con.execute_command('AI.MODELSTORE', model_name, 'TF', DEVICE,

tests/flow/tests_tflite.py

Lines changed: 69 additions & 9 deletions
@@ -1,4 +1,4 @@
-import redis
+import numpy as np
 
 from includes import *
 

@@ -47,7 +47,61 @@ def test_run_tflite_model(env):
     env.assertEqual(values[0], 1)
 
 
-def test_run_tflite_model_errors(env):
+def test_run_tflite_model_autobatch(env):
+    if not TEST_TFLITE:
+        env.debugPrint("skipping {} since TEST_TFLITE=0".format(sys._getframe().f_code.co_name), force=True)
+        return
+
+    con = env.getConnection()
+    model_pb = load_file_content('lite-model_imagenet_mobilenet_v3_small_100_224_classification_5_default_1.tflite')
+    _, _, _, img = load_resnet_test_data()
+    img = img.astype(np.float32) / 255
+
+    ret = con.execute_command('AI.MODELSTORE', 'm{1}', 'TFLITE', 'CPU',
+                              'BATCHSIZE', 4, 'MINBATCHSIZE', 2,
+                              'BLOB', model_pb)
+    env.assertEqual(ret, b'OK')
+
+    ret = con.execute_command('AI.MODELGET', 'm{1}', 'META')
+    env.assertEqual(len(ret), 16)
+    if DEVICE == "CPU":
+        env.assertEqual(ret[1], b'TFLITE')
+        env.assertEqual(ret[3], b'CPU')
+
+    ret = con.execute_command('AI.TENSORSET', 'a{1}',
+                              'FLOAT', 1, img.shape[1], img.shape[0], 3,
+                              'BLOB', img.tobytes())
+    env.assertEqual(ret, b'OK')
+
+    ret = con.execute_command('AI.TENSORSET', 'b{1}',
+                              'FLOAT', 1, img.shape[1], img.shape[0], 3,
+                              'BLOB', img.tobytes())
+    env.assertEqual(ret, b'OK')
+
+    def run():
+        con = env.getConnection()
+        con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1,
+                            'b{1}', 'OUTPUTS', 1, 'd{1}')
+        ensureSlaveSynced(con, env)
+
+    t = threading.Thread(target=run)
+    t.start()
+
+    con.execute_command('AI.MODELEXECUTE', 'm{1}', 'INPUTS', 1, 'a{1}', 'OUTPUTS', 1, 'c{1}')
+    t.join()
+
+    ensureSlaveSynced(con, env)
+
+    values = con.execute_command('AI.TENSORGET', 'c{1}', 'VALUES')
+    idx = np.argmax(values)
+    env.assertEqual(idx, 112)
+
+    values = con.execute_command('AI.TENSORGET', 'd{1}', 'VALUES')
+    idx = np.argmax(values)
+    env.assertEqual(idx, 112)
+
+
+def test_run_tflite_errors(env):
     if not TEST_TFLITE:
         env.debugPrint("skipping {} since TEST_TFLITE=0".format(sys._getframe().f_code.co_name), force=True)
         return

@@ -64,13 +118,6 @@ def test_run_tflite_model_errors(env):
     check_error_message(env, con, "Failed to load model from buffer",
                         'AI.MODELSTORE', 'm{1}', 'TFLITE', 'CPU', 'TAG', 'asdf', 'BLOB', wrong_model_pb)
 
-    # TODO: Autobatch is tricky with TFLITE because TFLITE expects a fixed batch
-    # size. At least we should constrain MINBATCHSIZE according to the
-    # hard-coded dims in the tflite model.
-    check_error_message(env, con, "Auto-batching not supported by the TFLITE backend",
-                        'AI.MODELSTORE', 'm{1}', 'TFLITE', 'CPU',
-                        'BATCHSIZE', 2, 'MINBATCHSIZE', 2, 'BLOB', model_pb)
-
     ret = con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)
     env.assertEqual(ret, b'OK')
 

@@ -82,6 +129,19 @@ def test_run_tflite_model_errors(env):
     check_error_message(env, con, "Number of keys given as INPUTS here does not match model definition",
                         'AI.MODELEXECUTE', 'm_2{1}', 'INPUTS', 3, 'a{1}', 'b{1}', 'c{1}', 'OUTPUTS', 1, 'd{1}')
 
+    model_pb = load_file_content('lite-model_imagenet_mobilenet_v3_small_100_224_classification_5_default_1.tflite')
+    _, _, _, img = load_resnet_test_data()
+
+    ret = con.execute_command('AI.MODELSTORE', 'image_net{1}', 'TFLITE', 'CPU', 'BLOB', model_pb)
+    env.assertEqual(ret, b'OK')
+    ret = con.execute_command('AI.TENSORSET', 'dog{1}', 'UINT8', 1, img.shape[1], img.shape[0], 3,
+                              'BLOB', img.tobytes())
+    env.assertEqual(ret, b'OK')
+
+    # The model expects FLOAT input, but UINT8 tensor is given.
+    check_error_message(env, con, "Input tensor type doesn't match the type expected by the model definition",
+                        'AI.MODELEXECUTE', 'image_net{1}', 'INPUTS', 1, 'dog{1}', 'OUTPUTS', 1, 'output{1}')
+
 
 def test_tflite_modelinfo(env):
     if not TEST_TFLITE:
