@@ -801,47 +801,54 @@ def test_run_tflite_model(env):
 # TODO: Autobatch is tricky with TFLITE because TFLITE expects a fixed batch
 # size. At least we should constrain MINBATCHSIZE according to the
 # hard-coded dims in the tflite model.
-# def test_run_tflite_model_autobatch(env):
-#     if not TEST_PT:
-#         return
-#
-#     con = env.getConnection()
-#
-#     test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
-#     model_filename = os.path.join(test_data_path, 'mnist_model_quant.tflite')
-#     sample_filename = os.path.join(test_data_path, 'one.raw')
-#
-#     with open(model_filename, 'rb') as f:
-#         model_pb = f.read()
-#
-#     with open(sample_filename, 'rb') as f:
-#         sample_raw = f.read()
-#
-#     ret = con.execute_command('AI.MODELSET', 'm', 'TFLITE', 'CPU',
-#                               'BATCHSIZE', 2, 'MINBATCHSIZE', 2, model_pb)
-#     env.assertEqual(ret, b'OK')
-#
-#     con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)
-#     con.execute_command('AI.TENSORSET', 'c', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)
-#
-#     def run():
-#         con = env.getConnection()
-#         con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'c', 'OUTPUTS', 'd', 'd2')
-#
-#     t = threading.Thread(target=run)
-#     t.start()
-#
-#     con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'OUTPUTS', 'b', 'b2')
-#
-#     tensor = con.execute_command('AI.TENSORGET', 'b', 'VALUES')
-#     value = tensor[-1][0]
-#
-#     env.assertEqual(value, 1)
-#
-#     tensor = con.execute_command('AI.TENSORGET', 'd', 'VALUES')
-#     value = tensor[-1][0]
-#
-#     env.assertEqual(value, 1)
+def test_run_tflite_model_autobatch(env):
+    if not TEST_PT:
+        return
+
+    con = env.getConnection()
+
+    test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
+    model_filename = os.path.join(test_data_path, 'mnist_model_quant.tflite')
+    sample_filename = os.path.join(test_data_path, 'one.raw')
+
+    with open(model_filename, 'rb') as f:
+        model_pb = f.read()
+
+    with open(sample_filename, 'rb') as f:
+        sample_raw = f.read()
+
+    # TFLITE expects a fixed batch size, so MODELSET with BATCHSIZE/MINBATCHSIZE should fail
+    exception = None
+    try:
+        ret = con.execute_command('AI.MODELSET', 'm', 'TFLITE', 'CPU',
+                                  'BATCHSIZE', 2, 'MINBATCHSIZE', 2, model_pb)
+    except Exception as e:
+        exception = e
+    env.assertEqual(type(exception), redis.exceptions.ResponseError)
+
+    # env.assertEqual(ret, b'OK')
+
+    # con.execute_command('AI.TENSORSET', 'a', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)
+    # con.execute_command('AI.TENSORSET', 'c', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)
+
+    # def run():
+    #     con = env.getConnection()
+    #     con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'c', 'OUTPUTS', 'd', 'd2')
+
+    # t = threading.Thread(target=run)
+    # t.start()
+
+    # con.execute_command('AI.MODELRUN', 'm', 'INPUTS', 'a', 'OUTPUTS', 'b', 'b2')
+
+    # tensor = con.execute_command('AI.TENSORGET', 'b', 'VALUES')
+    # value = tensor[-1][0]
+
+    # env.assertEqual(value, 1)
+
+    # tensor = con.execute_command('AI.TENSORGET', 'd', 'VALUES')
+    # value = tensor[-1][0]
+
+    # env.assertEqual(value, 1)


def test_set_tensor_multiproc(env):