Skip to content

Commit 9346ca6

Browse files
committed
Disable support for autobatching for TFLITE
1 parent d0962f8 commit 9346ca6

File tree

2 files changed

+52
-41
lines changed

2 files changed

+52
-41
lines changed

src/redisai.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,13 +667,19 @@ int RedisAI_ModelSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
667667

668668
unsigned long long batchsize = 0;
669669
if (AC_AdvanceIfMatch(&ac, "BATCHSIZE")) {
670+
if (backend == RAI_BACKEND_TFLITE) {
671+
return RedisModule_ReplyWithError(ctx, "Auto-batching not supported by the TFLITE backend.");
672+
}
670673
if (AC_GetUnsignedLongLong(&ac, &batchsize, 0) != AC_OK) {
671674
return RedisModule_ReplyWithError(ctx, "Invalid argument for BATCHSIZE.");
672675
}
673676
}
674677

675678
unsigned long long minbatchsize = 0;
676679
if (AC_AdvanceIfMatch(&ac, "MINBATCHSIZE")) {
680+
if (batchsize == 0) {
681+
return RedisModule_ReplyWithError(ctx, "MINBATCHSIZE specified without BATCHSIZE.");
682+
}
677683
if (AC_GetUnsignedLongLong(&ac, &minbatchsize, 0) != AC_OK) {
678684
return RedisModule_ReplyWithError(ctx, "Invalid argument for MINBATCHSIZE");
679685
}

test/basic_tests.py

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -801,47 +801,52 @@ def test_run_tflite_model(env):
801801
# TODO: Autobatch is tricky with TFLITE because TFLITE expects a fixed batch
802802
# size. At least we should constrain MINBATCHSIZE according to the
803803
# hard-coded dims in the tflite model.
804-
# def test_run_tflite_model_autobatch(env):
805-
# if not TEST_PT:
806-
# return
807-
#
808-
# con = env.getConnection()
809-
#
810-
# test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
811-
# model_filename = os.path.join(test_data_path, 'mnist_model_quant.tflite')
812-
# sample_filename = os.path.join(test_data_path, 'one.raw')
813-
#
814-
# with open(model_filename, 'rb') as f:
815-
# model_pb = f.read()
816-
#
817-
# with open(sample_filename, 'rb') as f:
818-
# sample_raw = f.read()
819-
#
820-
# ret = con.execute_command('AI.MODELSET', 'm', 'TFLITE', 'CPU',
821-
# 'BATCHSIZE', 2, 'MINBATCHSIZE', 2, model_pb)
822-
# env.assertEqual(ret, b'OK')
823-
#
824-
# con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)
825-
# con.execute_command('AI.TENSORSET', 'c', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)
826-
#
827-
# def run():
828-
# con = env.getConnection()
829-
# con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'c', 'OUTPUTS', 'd', 'd2')
830-
#
831-
# t = threading.Thread(target=run)
832-
# t.start()
833-
#
834-
# con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'OUTPUTS', 'b', 'b2')
835-
#
836-
# tensor = con.execute_command('AI.TENSORGET', 'b', 'VALUES')
837-
# value = tensor[-1][0]
838-
#
839-
# env.assertEqual(value, 1)
840-
#
841-
# tensor = con.execute_command('AI.TENSORGET', 'd', 'VALUES')
842-
# value = tensor[-1][0]
843-
#
844-
# env.assertEqual(value, 1)
804+
def test_run_tflite_model_autobatch(env):
805+
if not TEST_PT:
806+
return
807+
808+
con = env.getConnection()
809+
810+
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
811+
model_filename = os.path.join(test_data_path, 'mnist_model_quant.tflite')
812+
sample_filename = os.path.join(test_data_path, 'one.raw')
813+
814+
with open(model_filename, 'rb') as f:
815+
model_pb = f.read()
816+
817+
with open(sample_filename, 'rb') as f:
818+
sample_raw = f.read()
819+
820+
try:
821+
ret = con.execute_command('AI.MODELSET', 'm', 'TFLITE', 'CPU',
822+
'BATCHSIZE', 2, 'MINBATCHSIZE', 2, model_pb)
823+
except Exception as e:
824+
exception = e
825+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
826+
827+
# env.assertEqual(ret, b'OK')
828+
829+
# con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)
830+
# con.execute_command('AI.TENSORSET', 'c', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)
831+
832+
# def run():
833+
# con = env.getConnection()
834+
# con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'c', 'OUTPUTS', 'd', 'd2')
835+
836+
# t = threading.Thread(target=run)
837+
# t.start()
838+
839+
# con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'OUTPUTS', 'b', 'b2')
840+
841+
# tensor = con.execute_command('AI.TENSORGET', 'b', 'VALUES')
842+
# value = tensor[-1][0]
843+
844+
# env.assertEqual(value, 1)
845+
846+
# tensor = con.execute_command('AI.TENSORGET', 'd', 'VALUES')
847+
# value = tensor[-1][0]
848+
849+
# env.assertEqual(value, 1)
845850

846851

847852
def test_set_tensor_multiproc(env):

0 commit comments

Comments
 (0)