2
2
import sys
3
3
from io import StringIO
4
4
from unittest import TestCase
5
+ from skimage .io import imread
6
+ from skimage .transform import resize
5
7
6
8
import numpy as np
7
9
from ml2rt import load_model
12
14
DEBUG = False
13
15
tf_graph = "graph.pb"
14
16
torch_graph = "pt-minimal.pt"
17
+ dog_img = "dog.jpg"
15
18
16
19
17
20
class Capturing (list ):
@@ -32,6 +35,20 @@ def bar(a, b):
32
35
return a + b
33
36
"""
34
37
38
+ data_processing_script = r"""
39
+ def pre_process_3ch(image):
40
+ return image.float().div(255).unsqueeze(0)
41
+
42
+ def pre_process_4ch(image):
43
+ return image.float().div(255)[:,:,:-1].contiguous().unsqueeze(0)
44
+
45
+ def post_process(output):
46
+ # tf model has 1001 classes, hence negative 1
47
+ return output.max(1)[1] - 1
48
+
49
+ def ensemble(output0, output1):
50
+ return (output0 + output1) * 0.5
51
+ """
35
52
36
53
class RedisAITestBase (TestCase ):
37
54
def setUp (self ):
@@ -492,6 +509,16 @@ def test_debug(self):
492
509
self .assertEqual (["AI.TENSORSET x FLOAT 4 VALUES 2 3 4 5" ], output )
493
510
494
511
512
+ def load_image ():
513
+ image_filename = os .path .join (MODEL_DIR , dog_img )
514
+ img_height , img_width = 224 , 224
515
+
516
+ img = imread (image_filename )
517
+ img = resize (img , (img_height , img_width ), mode = 'constant' , anti_aliasing = True )
518
+ img = img .astype (np .uint8 )
519
+ return img
520
+
521
+
495
522
class DagTestCase (RedisAITestBase ):
496
523
def setUp (self ):
497
524
super ().setUp ()
@@ -515,6 +542,27 @@ def test_deprecated_dugrun(self):
515
542
self .assertTrue (np .allclose (result_outside_dag , result .pop ()))
516
543
self .assertEqual (expected , result )
517
544
545
+ """
546
+ def test_dagexecute_modelexecute_with_scriptexecute(self):
547
+ con = self.get_client()
548
+ script_name = 'imagenet_script:{1}'
549
+ model_name = 'imagenet_model:{1}'
550
+
551
+ img = load_image()
552
+ model_path = os.path.join(MODEL_DIR, "resnet50.pb")
553
+ model = load_model(model_path)
554
+ con.scriptset(script_name, 'cpu', data_processing_script)
555
+ con.modelstore(model_name, 'TF', 'cpu', model, inputs='images', outputs='output')
556
+
557
+ dag = con.dag(persist='output:{1}')
558
+ dag.tensorset('image:{1}', tensor=img, shape=(img.shape[1], img.shape[0]), dtype='UINT8')
559
+ dag.scriptexecute(script_name, 'pre_process_3ch', keys=[], inputs='image:{1}', outputs='temp_key1')
560
+ dag.modelexecute(model_name, inputs='temp_key1', outputs='temp_key2')
561
+ dag.scriptexecute(script_name, 'post_process', keys=[], inputs='temp_key2', outputs='output:{1}')
562
+ ret = dag.execute()
563
+ self.assertEqual(['OK', 'OK', 'OK', 'OK'], ret)
564
+ """
565
+
518
566
def test_dagexecute_with_load (self ):
519
567
con = self .get_client ()
520
568
con .tensorset ("a" , [2 , 3 , 2 , 3 ], shape = (2 , 2 ), dtype = "float" )
@@ -608,6 +656,10 @@ def test_dagexecuteRO(self):
608
656
with self .assertRaises (RuntimeError ):
609
657
con .dag (load = ["a" , "b" ], persist = "output" , readonly = True )
610
658
dag = con .dag (load = ["a" , "b" ], readonly = True )
659
+ """
660
+ with self.assertRaises(RuntimeError):
661
+ dag.scriptexecute()
662
+ """
611
663
dag .modelexecute ("pt_model" , ["a" , "b" ], ["output" ])
612
664
dag .tensorget ("output" )
613
665
result = dag .execute ()
0 commit comments