Skip to content

Commit cd83c2d

Browse files
support modelrun and run in Dag
1 parent e1572a2 commit cd83c2d

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

redisai/dag.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from redisai import command_builder as builder
77
from redisai.postprocessor import Processor
8+
from deprecated import deprecated
89

910
processor = Processor()
1011

@@ -13,6 +14,11 @@ class Dag:
1314
def __init__(self, load, persist, keys, timeout, executor, readonly=False, postprocess=True):
1415
self.result_processors = []
1516
self.enable_postprocess = True
17+
if load is None and persist is None and keys is None:
18+
raise RuntimeError(
19+
"AI.DAGEXECUTE and AI.DAGEXECUTE_RO commands must contain"
20+
"at least one out of KEYS, LOAD, PERSIST parameters"
21+
)
1622
if readonly:
1723
if persist:
1824
raise RuntimeError(
@@ -42,6 +48,7 @@ def __init__(self, load, persist, keys, timeout, executor, readonly=False, postp
4248

4349
self.commands.append("|>")
4450
self.executor = executor
51+
self.readonly = readonly
4552

4653
def tensorset(
4754
self,
@@ -76,6 +83,15 @@ def tensorget(
7683
)
7784
return self
7885

86+
@deprecated(version="1.2.0", reason="Use modelexecute instead")
87+
def modelrun(
88+
self,
89+
key: AnyStr,
90+
inputs: Union[AnyStr, List[AnyStr]],
91+
outputs: Union[AnyStr, List[AnyStr]],
92+
) -> Any:
93+
return self.modelexecute(key, inputs, outputs)
94+
7995
def modelexecute(
8096
self,
8197
key: AnyStr,
@@ -88,6 +104,10 @@ def modelexecute(
88104
self.result_processors.append(bytes.decode)
89105
return self
90106

107+
@deprecated(version="1.2.0", reason="Use execute instead")
108+
def run(self):
109+
return self.execute()
110+
91111
def execute(self):
92112
commands = self.commands[:-1] # removing the last "|>"
93113
results = self.executor(*commands)

test/test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,21 @@ def setUp(self):
500500
ptmodel = load_model(model_path)
501501
con.modelstore("pt_model", "torch", "cpu", ptmodel, tag="v7.0")
502502

503+
def test_deprecated_dugrun(self):
504+
con = self.get_client()
505+
con.tensorset("a", [2, 3, 2, 3], shape=(2, 2), dtype="float")
506+
con.tensorset("b", [2, 3, 2, 3], shape=(2, 2), dtype="float")
507+
dag = con.dag(load=["a", "b"], persist="output")
508+
dag.modelrun("pt_model", ["a", "b"], ["output"])
509+
dag.tensorget("output")
510+
result = dag.run()
511+
expected = ["OK", np.array([[4.0, 6.0], [4.0, 6.0]], dtype=np.float32)]
512+
result_outside_dag = con.tensorget("output")
513+
self.assertTrue(np.allclose(expected.pop(), result.pop()))
514+
result = dag.run()
515+
self.assertTrue(np.allclose(result_outside_dag, result.pop()))
516+
self.assertEqual(expected, result)
517+
503518
def test_dagrun_with_load(self):
504519
con = self.get_client()
505520
con.tensorset("a", [2, 3, 2, 3], shape=(2, 2), dtype="float")
@@ -549,6 +564,8 @@ def test_dagrun_calling_on_return(self):
549564

550565
def test_dagrun_without_load_and_persist(self):
551566
con = self.get_client()
567+
with self.assertRaises(RuntimeError):
568+
con.dag()
552569

553570
dag = con.dag(load="wrongkey")
554571
with self.assertRaises(ResponseError):

0 commit comments

Comments
 (0)