Skip to content

Commit b60dfe7

Browse files
author
hhsecond
committed
test cases for crash test
1 parent 9ab3cec commit b60dfe7

File tree

1 file changed

+72
-28
lines changed

1 file changed

+72
-28
lines changed

test/tests_tensorflow.py

Lines changed: 72 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import redis
2+
from functools import wraps
23

34
from includes import *
45

@@ -7,11 +8,19 @@
78
'''
89

910

10-
def test_run_mobilenet(env):
11-
if not TEST_TF:
12-
env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True)
13-
return
11+
def skip_if_no_TF(f):
12+
@wraps(f)
13+
def wrapper(env, *args, **kwargs):
14+
if not TEST_TF:
15+
env.debugPrint("skipping {} since TEST_TF=0".format(
16+
sys._getframe().f_code.co_name), force=True)
17+
return
18+
return f(env, *args, **kwargs)
19+
return wrapper
20+
1421

22+
@skip_if_no_TF
23+
def test_run_mobilenet(env):
1524
con = env.getConnection()
1625

1726
input_var = 'input'
@@ -69,11 +78,8 @@ def test_run_mobilenet(env):
6978
env.assertEqual(data, slave_data)
7079

7180

81+
@skip_if_no_TF
7282
def test_run_mobilenet_multiproc(env):
73-
if not TEST_TF:
74-
env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True)
75-
return
76-
7783
if VALGRIND:
7884
env.debugPrint("skipping {} since VALGRIND=1".format(sys._getframe().f_code.co_name), force=True)
7985
return
@@ -112,11 +118,8 @@ def test_run_mobilenet_multiproc(env):
112118
env.assertEqual(data, slave_data)
113119

114120

121+
@skip_if_no_TF
115122
def test_del_tf_model(env):
116-
if not TEST_TF:
117-
env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True)
118-
return
119-
120123
con = env.getConnection()
121124

122125
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
@@ -157,11 +160,8 @@ def test_del_tf_model(env):
157160
env.assertEqual("WRONGTYPE Operation against a key holding the wrong kind of value", exception.__str__())
158161

159162

163+
@skip_if_no_TF
160164
def test_run_tf_model(env):
161-
if not TEST_TF:
162-
env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True)
163-
return
164-
165165
con = env.getConnection()
166166

167167
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
@@ -217,11 +217,8 @@ def test_run_tf_model(env):
217217
env.assertFalse(con2.execute_command('EXISTS', 'm'))
218218

219219

220+
@skip_if_no_TF
220221
def test_run_tf_model_errors(env):
221-
if not TEST_TF:
222-
env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True)
223-
return
224-
225222
con = env.getConnection()
226223

227224
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
@@ -365,6 +362,7 @@ def test_run_tf_model_errors(env):
365362
env.assertEqual(type(exception), redis.exceptions.ResponseError)
366363

367364

365+
@skip_if_no_TF
368366
def test_run_tf_model_autobatch(env):
369367
if not TEST_PT:
370368
return
@@ -411,11 +409,8 @@ def run():
411409
env.assertEqual(values, [b'4', b'9', b'4', b'9'])
412410

413411

412+
@skip_if_no_TF
414413
def test_tensorflow_modelinfo(env):
415-
if not TEST_TF:
416-
env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True)
417-
return
418-
419414
con = env.getConnection()
420415

421416
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
@@ -466,11 +461,8 @@ def test_tensorflow_modelinfo(env):
466461
env.assertEqual(info_dict_0['ERRORS'], 0)
467462

468463

464+
@skip_if_no_TF
469465
def test_tensorflow_modelrun_disconnect(env):
470-
if not TEST_TF:
471-
env.debugPrint("skipping {} since TEST_TF=0".format(sys._getframe().f_code.co_name), force=True)
472-
return
473-
474466
red = env.getConnection()
475467

476468
test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
@@ -493,3 +485,55 @@ def test_tensorflow_modelrun_disconnect(env):
493485

494486
ret = send_and_disconnect(('AI.MODELRUN', 'm', 'INPUTS', 'a', 'b', 'OUTPUTS', 'c'), red)
495487
env.assertEqual(ret, None)
488+
489+
490+
@skip_if_no_TF
491+
def test_with_batch_and_minbatch(env):
492+
con = env.getConnection()
493+
batch_size = 2
494+
minbatch_size = 2
495+
model_name = 'model'
496+
another_model_name = 'another_model'
497+
inputvar = 'input'
498+
outputvar = 'MobilenetV2/Predictions/Reshape_1'
499+
500+
model_pb, labels, img = load_mobilenet_test_data()
501+
502+
con.execute_command('AI.MODELSET', model_name, 'TF', DEVICE,
503+
'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size,
504+
'INPUTS', inputvar,
505+
'OUTPUTS', outputvar,
506+
model_pb)
507+
con.execute_command('AI.TENSORSET', 'input',
508+
'FLOAT', 1, img.shape[1], img.shape[0], img.shape[2],
509+
'BLOB', img.tobytes())
510+
511+
def run(name=model_name):
512+
con.execute_command('AI.MODELRUN', name,
513+
'INPUTS', 'input', 'OUTPUTS', 'output')
514+
515+
# Running thrice since minbatchsize = 2
516+
threading.Thread(target=run).start()
517+
threading.Thread(target=run).start()
518+
threading.Thread(target=run).start()
519+
time.sleep(1)
520+
521+
# This is where the problem. If we set any new model (Note that the model
522+
# name has changed), then the subsequent requests fails
523+
con.execute_command('AI.MODELSET', another_model_name, 'TF', DEVICE,
524+
'BATCHSIZE', batch_size, 'MINBATCHSIZE', minbatch_size,
525+
'INPUTS', inputvar,
526+
'OUTPUTS', outputvar,
527+
model_pb)
528+
529+
threading.Thread(target=run, args=(another_model_name,)).start()
530+
threading.Thread(target=run, args=(another_model_name,)).start()
531+
532+
dtype, shape, data = con.execute_command('AI.TENSORGET', 'output', 'BLOB')
533+
dtype_map = {b'FLOAT': np.float32}
534+
tensor = np.frombuffer(data, dtype=dtype_map[dtype]).reshape(shape)
535+
label_id = np.argmax(tensor) - 1
536+
537+
_, label = labels[str(label_id)]
538+
539+
env.assertEqual(label, 'giant_panda')

0 commit comments

Comments
 (0)