Skip to content

Commit ebce97d

Browse files
add scriptexecute support with test
1 parent d008db9 commit ebce97d

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

redisai/dag.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,24 @@ def modelexecute(
104104
self.result_processors.append(bytes.decode)
105105
return self
106106

107+
def scriptexecute(
108+
self,
109+
key: AnyStr,
110+
function: str,
111+
keys: Union[AnyStr, Sequence[AnyStr]],
112+
inputs: Union[AnyStr, Sequence[Union[AnyStr, Sequence[AnyStr]]]] = None,
113+
outputs: Union[AnyStr, List[AnyStr]] = None,
114+
) -> Any:
115+
if self.readonly:
116+
raise RuntimeError(
117+
"AI.SCRIPTEXECUTE cannot be used in readonly mode"
118+
)
119+
args = builder.scriptexecute(key, function, keys, inputs, outputs, None)
120+
self.commands.extend(args)
121+
self.commands.append("|>")
122+
self.result_processors.append(bytes.decode)
123+
return self
124+
107125
@deprecated(version="1.2.0", reason="Use execute instead")
108126
def run(self):
109127
return self.execute()

test-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ nose
66
codecov
77
numpy
88
ml2rt
9-
deprecated
9+
deprecated
10+
scikit-image

test/test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import sys
33
from io import StringIO
44
from unittest import TestCase
5+
from skimage.io import imread
6+
from skimage.transform import resize
57

68
import numpy as np
79
from ml2rt import load_model
@@ -12,6 +14,7 @@
1214
DEBUG = False
1315
tf_graph = "graph.pb"
1416
torch_graph = "pt-minimal.pt"
17+
dog_img = "dog.jpg"
1518

1619

1720
class Capturing(list):
@@ -32,6 +35,20 @@ def bar(a, b):
3235
return a + b
3336
"""
3437

38+
data_processing_script = r"""
39+
def pre_process_3ch(image):
40+
return image.float().div(255).unsqueeze(0)
41+
42+
def pre_process_4ch(image):
43+
return image.float().div(255)[:,:,:-1].contiguous().unsqueeze(0)
44+
45+
def post_process(output):
46+
# tf model has 1001 classes, hence negative 1
47+
return output.max(1)[1] - 1
48+
49+
def ensemble(output0, output1):
50+
return (output0 + output1) * 0.5
51+
"""
3552

3653
class RedisAITestBase(TestCase):
3754
def setUp(self):
@@ -492,6 +509,16 @@ def test_debug(self):
492509
self.assertEqual(["AI.TENSORSET x FLOAT 4 VALUES 2 3 4 5"], output)
493510

494511

512+
def load_image():
513+
image_filename = os.path.join(MODEL_DIR, dog_img)
514+
img_height, img_width = 224, 224
515+
516+
img = imread(image_filename)
517+
img = resize(img, (img_height, img_width), mode='constant', anti_aliasing=True)
518+
img = img.astype(np.uint8)
519+
return img
520+
521+
495522
class DagTestCase(RedisAITestBase):
496523
def setUp(self):
497524
super().setUp()
@@ -515,6 +542,27 @@ def test_deprecated_dugrun(self):
515542
self.assertTrue(np.allclose(result_outside_dag, result.pop()))
516543
self.assertEqual(expected, result)
517544

545+
"""
546+
def test_dagexecute_modelexecute_with_scriptexecute(self):
547+
con = self.get_client()
548+
script_name = 'imagenet_script:{1}'
549+
model_name = 'imagenet_model:{1}'
550+
551+
img = load_image()
552+
model_path = os.path.join(MODEL_DIR, "resnet50.pb")
553+
model = load_model(model_path)
554+
con.scriptset(script_name, 'cpu', data_processing_script)
555+
con.modelstore(model_name, 'TF', 'cpu', model, inputs='images', outputs='output')
556+
557+
dag = con.dag(persist='output:{1}')
558+
dag.tensorset('image:{1}', tensor=img, shape=(img.shape[1], img.shape[0]), dtype='UINT8')
559+
dag.scriptexecute(script_name, 'pre_process_3ch', keys=[], inputs='image:{1}', outputs='temp_key1')
560+
dag.modelexecute(model_name, inputs='temp_key1', outputs='temp_key2')
561+
dag.scriptexecute(script_name, 'post_process', keys=[], inputs='temp_key2', outputs='output:{1}')
562+
ret = dag.execute()
563+
self.assertEqual(['OK', 'OK', 'OK', 'OK'], ret)
564+
"""
565+
518566
def test_dagexecute_with_load(self):
519567
con = self.get_client()
520568
con.tensorset("a", [2, 3, 2, 3], shape=(2, 2), dtype="float")
@@ -608,6 +656,10 @@ def test_dagexecuteRO(self):
608656
with self.assertRaises(RuntimeError):
609657
con.dag(load=["a", "b"], persist="output", readonly=True)
610658
dag = con.dag(load=["a", "b"], readonly=True)
659+
"""
660+
with self.assertRaises(RuntimeError):
661+
dag.scriptexecute()
662+
"""
611663
dag.modelexecute("pt_model", ["a", "b"], ["output"])
612664
dag.tensorget("output")
613665
result = dag.execute()

0 commit comments

Comments
 (0)