1
1
import redis
2
+ from functools import wraps
2
3
3
4
from includes import *
4
5
7
8
'''
8
9
9
10
10
- def test_run_mobilenet (env ):
11
- if not TEST_TF :
12
- env .debugPrint ("skipping {} since TEST_TF=0" .format (sys ._getframe ().f_code .co_name ), force = True )
13
- return
11
+ def skip_if_no_TF (f ):
12
+ @wraps (f )
13
+ def wrapper (env , * args , ** kwargs ):
14
+ if not TEST_TF :
15
+ env .debugPrint ("skipping {} since TEST_TF=0" .format (
16
+ sys ._getframe ().f_code .co_name ), force = True )
17
+ return
18
+ return f (env , * args , ** kwargs )
19
+ return wrapper
20
+
14
21
22
+ @skip_if_no_TF
23
+ def test_run_mobilenet (env ):
15
24
con = env .getConnection ()
16
25
17
26
input_var = 'input'
@@ -69,11 +78,8 @@ def test_run_mobilenet(env):
69
78
env .assertEqual (data , slave_data )
70
79
71
80
81
+ @skip_if_no_TF
72
82
def test_run_mobilenet_multiproc (env ):
73
- if not TEST_TF :
74
- env .debugPrint ("skipping {} since TEST_TF=0" .format (sys ._getframe ().f_code .co_name ), force = True )
75
- return
76
-
77
83
if VALGRIND :
78
84
env .debugPrint ("skipping {} since VALGRIND=1" .format (sys ._getframe ().f_code .co_name ), force = True )
79
85
return
@@ -112,11 +118,8 @@ def test_run_mobilenet_multiproc(env):
112
118
env .assertEqual (data , slave_data )
113
119
114
120
121
+ @skip_if_no_TF
115
122
def test_del_tf_model (env ):
116
- if not TEST_TF :
117
- env .debugPrint ("skipping {} since TEST_TF=0" .format (sys ._getframe ().f_code .co_name ), force = True )
118
- return
119
-
120
123
con = env .getConnection ()
121
124
122
125
test_data_path = os .path .join (os .path .dirname (__file__ ), 'test_data' )
@@ -157,11 +160,8 @@ def test_del_tf_model(env):
157
160
env .assertEqual ("WRONGTYPE Operation against a key holding the wrong kind of value" , exception .__str__ ())
158
161
159
162
163
+ @skip_if_no_TF
160
164
def test_run_tf_model (env ):
161
- if not TEST_TF :
162
- env .debugPrint ("skipping {} since TEST_TF=0" .format (sys ._getframe ().f_code .co_name ), force = True )
163
- return
164
-
165
165
con = env .getConnection ()
166
166
167
167
test_data_path = os .path .join (os .path .dirname (__file__ ), 'test_data' )
@@ -217,11 +217,8 @@ def test_run_tf_model(env):
217
217
env .assertFalse (con2 .execute_command ('EXISTS' , 'm' ))
218
218
219
219
220
+ @skip_if_no_TF
220
221
def test_run_tf_model_errors (env ):
221
- if not TEST_TF :
222
- env .debugPrint ("skipping {} since TEST_TF=0" .format (sys ._getframe ().f_code .co_name ), force = True )
223
- return
224
-
225
222
con = env .getConnection ()
226
223
227
224
test_data_path = os .path .join (os .path .dirname (__file__ ), 'test_data' )
@@ -365,6 +362,7 @@ def test_run_tf_model_errors(env):
365
362
env .assertEqual (type (exception ), redis .exceptions .ResponseError )
366
363
367
364
365
+ @skip_if_no_TF
368
366
def test_run_tf_model_autobatch (env ):
369
367
if not TEST_PT :
370
368
return
@@ -411,11 +409,8 @@ def run():
411
409
env .assertEqual (values , [b'4' , b'9' , b'4' , b'9' ])
412
410
413
411
412
+ @skip_if_no_TF
414
413
def test_tensorflow_modelinfo (env ):
415
- if not TEST_TF :
416
- env .debugPrint ("skipping {} since TEST_TF=0" .format (sys ._getframe ().f_code .co_name ), force = True )
417
- return
418
-
419
414
con = env .getConnection ()
420
415
421
416
test_data_path = os .path .join (os .path .dirname (__file__ ), 'test_data' )
@@ -466,11 +461,8 @@ def test_tensorflow_modelinfo(env):
466
461
env .assertEqual (info_dict_0 ['ERRORS' ], 0 )
467
462
468
463
464
+ @skip_if_no_TF
469
465
def test_tensorflow_modelrun_disconnect (env ):
470
- if not TEST_TF :
471
- env .debugPrint ("skipping {} since TEST_TF=0" .format (sys ._getframe ().f_code .co_name ), force = True )
472
- return
473
-
474
466
red = env .getConnection ()
475
467
476
468
test_data_path = os .path .join (os .path .dirname (__file__ ), 'test_data' )
@@ -493,3 +485,55 @@ def test_tensorflow_modelrun_disconnect(env):
493
485
494
486
ret = send_and_disconnect (('AI.MODELRUN' , 'm' , 'INPUTS' , 'a' , 'b' , 'OUTPUTS' , 'c' ), red )
495
487
env .assertEqual (ret , None )
488
+
489
+
490
+ @skip_if_no_TF
491
+ def test_with_batch_and_minbatch (env ):
492
+ con = env .getConnection ()
493
+ batch_size = 2
494
+ minbatch_size = 2
495
+ model_name = 'model'
496
+ another_model_name = 'another_model'
497
+ inputvar = 'input'
498
+ outputvar = 'MobilenetV2/Predictions/Reshape_1'
499
+
500
+ model_pb , labels , img = load_mobilenet_test_data ()
501
+
502
+ con .execute_command ('AI.MODELSET' , model_name , 'TF' , DEVICE ,
503
+ 'BATCHSIZE' , batch_size , 'MINBATCHSIZE' , minbatch_size ,
504
+ 'INPUTS' , inputvar ,
505
+ 'OUTPUTS' , outputvar ,
506
+ model_pb )
507
+ con .execute_command ('AI.TENSORSET' , 'input' ,
508
+ 'FLOAT' , 1 , img .shape [1 ], img .shape [0 ], img .shape [2 ],
509
+ 'BLOB' , img .tobytes ())
510
+
511
+ def run (name = model_name ):
512
+ con .execute_command ('AI.MODELRUN' , name ,
513
+ 'INPUTS' , 'input' , 'OUTPUTS' , 'output' )
514
+
515
+ # Running thrice since minbatchsize = 2
516
+ threading .Thread (target = run ).start ()
517
+ threading .Thread (target = run ).start ()
518
+ threading .Thread (target = run ).start ()
519
+ time .sleep (1 )
520
+
521
+ # This is where the problem. If we set any new model (Note that the model
522
+ # name has changed), then the subsequent requests fails
523
+ con .execute_command ('AI.MODELSET' , another_model_name , 'TF' , DEVICE ,
524
+ 'BATCHSIZE' , batch_size , 'MINBATCHSIZE' , minbatch_size ,
525
+ 'INPUTS' , inputvar ,
526
+ 'OUTPUTS' , outputvar ,
527
+ model_pb )
528
+
529
+ threading .Thread (target = run , args = (another_model_name ,)).start ()
530
+ threading .Thread (target = run , args = (another_model_name ,)).start ()
531
+
532
+ dtype , shape , data = con .execute_command ('AI.TENSORGET' , 'output' , 'BLOB' )
533
+ dtype_map = {b'FLOAT' : np .float32 }
534
+ tensor = np .frombuffer (data , dtype = dtype_map [dtype ]).reshape (shape )
535
+ label_id = np .argmax (tensor ) - 1
536
+
537
+ _ , label = labels [str (label_id )]
538
+
539
+ env .assertEqual (label , 'giant_panda' )
0 commit comments