Skip to content

Commit 10d857f

Browse files
authored
Let ai.scriptget and ai.modelget commands run successfully without optional args (#791)
* Align script get command, so it will return both source and meta if no optional argument is specified. * Documentation fixes * Add default behaviour for AI.MODELGET + documentation
1 parent 33c2b59 commit 10d857f

File tree

5 files changed

+100
-84
lines changed

5 files changed

+100
-84
lines changed

docs/commands.md

Lines changed: 31 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,8 @@ AI.MODELGET <key> [META] [BLOB]
222222
_Arguments
223223

224224
* **key**: the model's key name
225-
* **META**: will return the model's meta information on backend, device, tag and batching parameters
226-
* **BLOB**: will return the model's blob containing the serialized model
225+
* **META**: will return only the model's meta information on backend, device, tag and batching parameters
226+
* **BLOB**: will return only the model's blob containing the serialized model
227227

228228
_Return_
229229

@@ -237,7 +237,7 @@ An array of alternating key-value pairs as follows:
237237
1. **INPUTS**: array reply with one or more names of the model's input nodes (applicable only for TensorFlow models)
238238
1. **OUTPUTS**: array reply with one or more names of the model's output nodes (applicable only for TensorFlow models)
239239
1. **MINBATCHTIMEOUT**: The time in milliseconds for which the engine will wait before executing a request to run the model, when the number of incoming requests is lower than `MINBATCHSIZE`. When `MINBATCHTIMEOUT` is 0, the engine will not run the model before it receives at least `MINBATCHSIZE` requests.
240-
1. **BLOB**: a blob containing the serialized model (when called with the `BLOB` argument) as a String. If the size of the serialized model exceeds `MODEL_CHUNK_SIZE` (see `AI.CONFIG` command), then an array of chunks is returned. The full serialized model can be obtained by concatenating the chunks.
240+
1. **BLOB**: a blob containing the serialized model as a String. If the size of the serialized model exceeds `MODEL_CHUNK_SIZE` (see `AI.CONFIG` command), then an array of chunks is returned. The full serialized model can be obtained by concatenating the chunks.
241241

242242
**Examples**
243243

@@ -415,7 +415,7 @@ The **`AI.SCRIPTSTORE`** command stores a [TorchScript](https://pytorch.org/docs
415415
**Redis API**
416416

417417
```
418-
AI.SCRIPTSTORE <key> <device> [TAG tag] ENTRY_POINTS <entry_point_amoint> <entry_point> [<entry_point>...] SOURCE "<script>"
418+
AI.SCRIPTSTORE <key> <device> [TAG tag] ENTRY_POINTS <entry_points_count> <entry_point> [<entry_point>...] SOURCE "<script>"
419419
```
420420

421421
_Arguments_
@@ -427,7 +427,7 @@ _Arguments_
427427
* **CPU**: a CPU device
428428
* **GPU**: a GPU device
429429
* **GPU:0**, ..., **GPU:n**: a specific GPU device on a multi-GPU system
430-
* **ENTRY_POINTS** A list of entry points to be used in the script. Each entry point should have the signature of `def entry_point(tensors: List[Tensor], keys: List[str], args: List[str])`. The purpose of each list is as follows:
430+
* **ENTRY_POINTS** A list of function names in the script to be used as entry points upon execution. Each entry point should have the signature of `def entry_point(tensors: List[Tensor], keys: List[str], args: List[str])`. The purpose of each list is as follows:
431431
* `tensors`: A list holding the input tensors to the function.
432432
* `keys`: A list of keys that the torch script is about to preform read/write operations on.
433433
* `args`: A list of additional arguments to the function. If the desired argument is not from type string, it is up to the caller to cast it to the right type, within the script.
@@ -510,18 +510,18 @@ AI.SCRIPTGET <key> [META] [SOURCE]
510510
_Arguments_
511511

512512
* **key**: the script's key name
513-
* **META**: will return the script's meta information on device and tag
514-
* **SOURCE**: will return a string containing [TorchScript](https://pytorch.org/docs/stable/jit.html) source code
513+
* **META**: will return only the script's meta information on device, tag and entry points.
514+
* **SOURCE**: will return only the string containing [TorchScript](https://pytorch.org/docs/stable/jit.html) source code
515515

516516
_Return_
517517

518518
An array with alternating entries that represent the following key-value pairs:
519-
!!!!The command returns a list of key-value strings, namely `DEVICE device TAG tag [SOURCE source]`.
519+
!!!!The command returns a list of key-value strings, namely `DEVICE device TAG tag ENTRY_POINTS [entry_point ...] SOURCE source`.
520520

521521
1. **DEVICE**: the script's device as a String
522522
2. **TAG**: the scripts's tag as a String
523523
3. **SOURCE**: the script's source code as a String
524-
4. **ENTRY_POINTS** will return an array containing the script entry points
524+
4. **ENTRY_POINTS** will return an array containing the script entry point functions
525525

526526
**Examples**
527527

@@ -570,7 +570,7 @@ OK
570570

571571
## AI.SCRIPTEXECUTE
572572

573-
The **`AI.SCRIPTEXECUTE`** command runs a script stored as a key's value on its specified device. It a list of keys, input tensors and addtional script args.
573+
The **`AI.SCRIPTEXECUTE`** command runs a script stored as a key's value on its specified device. It receives a list of Redis keys, a list of input tensors and an additional list of arguments to be used in the script.
574574

575575
The run request is put in a queue and is executed asynchronously by a worker thread. The client that had issued the run request is blocked until the script run is completed. When needed, tensors data is automatically copied to the device prior to execution.
576576

@@ -583,25 +583,25 @@ A `TIMEOUT t` argument can be specified to cause a request to be removed from th
583583

584584
```
585585
AI.SCRIPTEXECUTE <key> <function>
586-
[KEYS n <key> [keys...]]
587-
[INPUTS m <input> [input ...]]
588-
[ARGS k <arg> [arg...]]
589-
[OUTPUTS k <output> [output ...] [TIMEOUT t]]+
586+
[KEYS <keys_count> <key> [keys...]]
587+
[INPUTS <input_count> <input> [input ...]]
588+
[ARGS <args_count> <arg> [arg...]]
589+
[OUTPUTS <outputs_count> <output> [output ...]]
590+
[TIMEOUT t]
590591
```
591592

592593
_Arguments_
593594

594-
* **key**: the script's key name
595-
* **function**: the name of the function to run
596-
* **KEYS**: Either a squence of key names that the script will access before, during and after its execution, or a tag which all those keys share.
597-
* **INPUTS**: Denotes the beginning of the input parameters list, followed by its length and one or more input tensors.
598-
* **ARGS**: A list additional arguments that a user can send to the script. All args are sent as strings, but can be casted to other types supported by torch script, such as `int`, or `float`.
599-
595+
* **key**: the script's key name.
596+
* **function**: the name of the entry point function to run.
597+
* **KEYS**: Denotes the beginning of a list of Redis key names that the script will access to during its execution, for both read and/or write operations.
598+
* **INPUTS**: Denotes the beginning of the input tensors list, followed by its length and one or more input tensors.
599+
* **ARGS**: Denotes the beginning of a list of additional arguments that a user can send to the script. All args are sent as strings, but can be casted to other types supported by torch script, such as `int`, or `float`.
600600
* **OUTPUTS**: denotes the beginning of the output tensors keys' list, followed by its length and one or more key names.
601601
* **TIMEOUT**: the time (in ms) after which the client is unblocked and a `TIMEDOUT` string is returned
602602

603603
Note:
604-
Either `KEYS` or `INPUTS` scopes should be provided this command (one or both scopes are acceptable). Those scopes indicate keyspace access and such, the right shard to execute the command at. Redis will verify that all potional key accesses are done to the right shard.
604+
Either `KEYS` or `INPUTS` scopes should be provided this command (one or both scopes are acceptable). Those scopes indicate keyspace access and such, the right shard to execute the command at. Redis will verify that all potential key accesses are done to the right shard.
605605

606606
_Return_
607607

@@ -611,27 +611,12 @@ A simple 'OK' string, a simple `TIMEDOUT` string, or an error.
611611

612612
The following is an example of running the previously-created 'myscript' on two input tensors:
613613

614-
```
615-
redis> AI.TENSORSET mytensor1 FLOAT 1 VALUES 40
616-
OK
617-
redis> AI.TENSORSET mytensor2 FLOAT 1 VALUES 2
618-
OK
619-
redis> AI.SCRIPTEXECUTE myscript addtwo KEYS 3 mytensor1 mytensor2 result INPUTS 2 mytensor1 mytensor2 OUTPUTS 1 result
620-
OK
621-
redis> AI.TENSORGET result VALUES
622-
1) FLOAT
623-
2) 1) (integer) 1
624-
3) 1) "42"
625-
```
626-
627-
Note: The above command could be executed with a shorter version, given all the keys are tagged with the same tag:
628-
629614
```
630615
redis> AI.TENSORSET mytensor1{tag} FLOAT 1 VALUES 40
631616
OK
632617
redis> AI.TENSORSET mytensor2{tag} FLOAT 1 VALUES 2
633618
OK
634-
redis> AI.SCRIPTEXECUTE myscript{tag} addtwo KEYS 1 {tag} INPUTS 2 mytensor1{tag} mytensor2{tag} OUTPUTS 1 result{tag}
619+
redis> AI.SCRIPTEXECUTE myscript{tag} addtwo INPUTS 2 mytensor1{tag} mytensor2{tag} OUTPUTS 1 result{tag}
635620
OK
636621
redis> AI.TENSORGET result{tag} VALUES
637622
1) FLOAT
@@ -652,18 +637,18 @@ redis> AI.TENSORSET mytensor2{tag} FLOAT 1 VALUES 1
652637
OK
653638
redis> AI.TENSORSET mytensor3{tag} FLOAT 1 VALUES 1
654639
OK
655-
redis> AI.SCRIPTEXECUTE myscript{tag} addn keys 1 {tag} INPUTS 3 mytensor1{tag} mytensor2{tag} mytensor3{tag} OUTPUTS 1 result{tag}
640+
redis> AI.SCRIPTEXECUTE myscript{tag} addn INPUTS 3 mytensor1{tag} mytensor2{tag} mytensor3{tag} OUTPUTS 1 result{tag}
656641
OK
657642
redis> AI.TENSORGET result{tag} VALUES
658643
1) FLOAT
659644
2) 1) (integer) 1
660645
3) 1) "42"
661646
```
662647

663-
Note: for the time being, as `AI.SCRIPTSET` is still avialable to use, `AI.SCRIPTEXECUTE` still supports running functions that are part of scripts stored with `AI.SCRIPTSET` or imported from old RDB/AOF files. Meaning calling `AI.SCRIPTEXECUTE` over a function without the dedicated signature of `(tensors: List[Tensor], keys: List[str], args: List[str]` will yield a "best effort" execution to match the deprecated API `AI.SCRIPTRUN` function execution. This will map `INPUTS` tensors only, to their counterpart input arguments in the function, according to the order which they apear.
648+
Note: for the time being, as `AI.SCRIPTSET` is still available to use, `AI.SCRIPTEXECUTE` still supports running functions that are part of scripts stored with `AI.SCRIPTSET` or imported from old RDB/AOF files. Meaning calling `AI.SCRIPTEXECUTE` over a function without the dedicated signature of `(tensors: List[Tensor], keys: List[str], args: List[str]` will yield a "best effort" execution to match the deprecated API `AI.SCRIPTRUN` function execution. This will map `INPUTS` tensors only, to their counterpart input arguments in the function, according to the order which they appear.
664649

665650
### Redis Commands support.
666-
In RedisAI TorchScript now supports simple (non-blocking) Redis commnands via the `redis.execute` API. The following script gets a key name (`x{1}`), and an `int` value (3). First, the script `SET`s the value in the key. Next, the script `GET`s the value back from the key, and sets it in a tensor which is eventually stored under the key 'y{1}'. Note that the inputs are `str` and `int`. The script sets and gets the value and set it into a tensor.
651+
In RedisAI TorchScript now supports simple (non-blocking) Redis commands via the `redis.execute` API. The following script gets a key name (`x{1}`), and an `int` value (3). First, the script `SET`s the value in the key. Next, the script `GET`s the value back from the key, and sets it in a tensor which is eventually stored under the key 'y{1}'. Note that the inputs are `str` and `int`. The script sets and gets the value and set it into a tensor.
667652

668653
```
669654
def redis_int_to_tensor(redis_value: int):
@@ -692,13 +677,13 @@ The command receives 3 inputs:
692677
Return value - the model execution output tensors (List of torch.Tensor)
693678
The following script creates two tensors, and executes the (tensorflow) model which is stored under the name 'tf_mul{1}' with these two tensors as inputs.
694679
```
695-
def test_model_execute(keys:List[str]):
680+
def test_model_execute(tensors: List[Tensor], keys: List[str], args: List[str]):
696681
a = torch.tensor([[2.0, 3.0], [2.0, 3.0]])
697682
b = torch.tensor([[2.0, 3.0], [2.0, 3.0]])
698683
return redisAI.model_execute(keys[0], [a, b], 1) # assume keys[0] is the model name stored in RedisAI.
699684
```
700685
```
701-
redis> AI.SCRIPTEXECUTE redis_scripts{1} test_model_execute KEYS 1 {1} LIST_INPUTS 1 tf_mul{1} OUTPUTS 1 y{1}
686+
redis> AI.SCRIPTEXECUTE redis_scripts{1} test_model_execute KEYS 1 tf_mul{1} OUTPUTS 1 y{1}
702687
OK
703688
redis> AI.TENSORGET y{1} VALUES
704689
1) (float) 4
@@ -833,9 +818,9 @@ A `TIMEOUT t` argument can be specified to cause a request to be removed from th
833818
**Redis API**
834819

835820
```
836-
AI.DAGEXECUTE [[LOAD <n> <key-1> <key-2> ... <key-n>] |
837-
[PERSIST <n> <key-1> <key-2> ... <key-n>] |
838-
[ROUTING <routing_tag>]]
821+
AI.DAGEXECUTE [LOAD <n> <key-1> <key-2> ... <key-n>]
822+
[PERSIST <n> <key-1> <key-2> ... <key-n>]
823+
[ROUTING <routing_tag>]
839824
[TIMEOUT t]
840825
|> <command> [|> command ...]
841826
```
@@ -844,7 +829,7 @@ _Arguments_
844829

845830
* **LOAD**: denotes the beginning of the input tensors keys' list, followed by the number of keys, and one or more key names
846831
* **PERSIST**: denotes the beginning of the output tensors keys' list, followed by the number of keys, and one or more key names
847-
* **ROUTING**: denotes the a key name or a tag that will assist in routing the dag execution command to the right shard. Redis will verify that all potential key accesses are done to within the target shard.
832+
* **ROUTING**: denotes a key to be used in the DAG or a tag that will assist in routing the dag execution command to the right shard. Redis will verify that all potential key accesses are done to within the target shard.
848833

849834
_While each of the LOAD, PERSIST and ROUTING sections are optional (and may appear at most once in the command), the command must contain **at least one** of these 3 keywords._
850835
* **TIMEOUT**: an optional argument, denotes the time (in ms) after which the client is unblocked and a `TIMEDOUT` string is returned

src/redisai.c

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -416,31 +416,24 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
416416
return REDISMODULE_ERR;
417417
}
418418

419-
int meta = 0;
420-
int blob = 0;
419+
int meta = false;
420+
int blob = false;
421421
for (int i = 2; i < argc; i++) {
422422
const char *optstr = RedisModule_StringPtrLen(argv[i], NULL);
423423
if (!strcasecmp(optstr, "META")) {
424-
meta = 1;
424+
meta = true;
425425
} else if (!strcasecmp(optstr, "BLOB")) {
426-
blob = 1;
426+
blob = true;
427427
}
428428
}
429429

430-
if (!meta && !blob) {
431-
return RedisModule_ReplyWithError(ctx, "ERR no META or BLOB specified");
432-
}
433-
434430
char *buffer = NULL;
435431
size_t len = 0;
436432

437-
if (blob) {
433+
if (!meta || blob) {
438434
RAI_ModelSerialize(mto, &buffer, &len, &err);
439-
if (err.code != RAI_OK) {
440-
#ifdef RAI_PRINT_BACKEND_ERRORS
441-
printf("ERR: %s\n", err.detail);
442-
#endif
443-
int ret = RedisModule_ReplyWithError(ctx, err.detail);
435+
if (RAI_GetErrorCode(&err) != RAI_OK) {
436+
int ret = RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err));
444437
RAI_ClearError(&err);
445438
if (*buffer) {
446439
RedisModule_Free(buffer);
@@ -455,12 +448,14 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
455448
return REDISMODULE_OK;
456449
}
457450

458-
const int outentries = blob ? 18 : 16;
459-
RedisModule_ReplyWithArray(ctx, outentries);
451+
// The only case where we return only META, is when META is given but BLOB
452+
// was not. Otherwise, we return both META+SOURCE
453+
const int out_entries = (meta && !blob) ? 16 : 18;
454+
RedisModule_ReplyWithArray(ctx, out_entries);
460455

461456
RedisModule_ReplyWithCString(ctx, "backend");
462-
const char *backendstr = RAI_GetBackendName(mto->backend);
463-
RedisModule_ReplyWithCString(ctx, backendstr);
457+
const char *backend_str = RAI_GetBackendName(mto->backend);
458+
RedisModule_ReplyWithCString(ctx, backend_str);
464459

465460
RedisModule_ReplyWithCString(ctx, "device");
466461
RedisModule_ReplyWithCString(ctx, mto->devicestr);
@@ -495,7 +490,8 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
495490
RedisModule_ReplyWithCString(ctx, "minbatchtimeout");
496491
RedisModule_ReplyWithLongLong(ctx, (long)mto->opts.minbatchtimeout);
497492

498-
if (meta && blob) {
493+
// This condition is the negation of (meta && !blob)
494+
if (!meta || blob) {
499495
RedisModule_ReplyWithCString(ctx, "blob");
500496
RAI_ReplyWithChunks(ctx, buffer, len);
501497
RedisModule_Free(buffer);
@@ -651,43 +647,40 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
651647
return REDISMODULE_ERR;
652648
}
653649

654-
int meta = 0;
655-
int source = 0;
650+
bool meta = false; // Indicates whether META argument was given.
651+
bool source = false; // Indicates whether SOURCE argument was given.
656652
for (int i = 2; i < argc; i++) {
657653
const char *optstr = RedisModule_StringPtrLen(argv[i], NULL);
658654
if (!strcasecmp(optstr, "META")) {
659-
meta = 1;
655+
meta = true;
660656
} else if (!strcasecmp(optstr, "SOURCE")) {
661-
source = 1;
657+
source = true;
662658
}
663659
}
664-
665-
if (!meta && !source) {
666-
return RedisModule_ReplyWithError(ctx, "ERR no META or SOURCE specified");
667-
}
668-
660+
// If only SOURCE arg was given, return only the script source.
669661
if (!meta && source) {
670662
RedisModule_ReplyWithCString(ctx, sto->scriptdef);
671663
return REDISMODULE_OK;
672664
}
665+
// We return (META+SOURCE) if both args are given, or if none of them was given.
666+
// The only case where we return only META data, is if META is given while SOURCE was not.
667+
int out_entries = (source || !meta) ? 8 : 6;
668+
RedisModule_ReplyWithArray(ctx, out_entries);
673669

674-
int outentries = source ? 8 : 6;
675-
676-
RedisModule_ReplyWithArray(ctx, outentries);
677670
RedisModule_ReplyWithCString(ctx, "device");
678671
RedisModule_ReplyWithCString(ctx, sto->devicestr);
679672
RedisModule_ReplyWithCString(ctx, "tag");
680673
RedisModule_ReplyWithString(ctx, sto->tag);
681-
if (source) {
682-
RedisModule_ReplyWithCString(ctx, "source");
683-
RedisModule_ReplyWithCString(ctx, sto->scriptdef);
684-
}
685674
RedisModule_ReplyWithCString(ctx, "Entry Points");
686675
size_t nEntryPoints = array_len(sto->entryPoints);
687676
RedisModule_ReplyWithArray(ctx, nEntryPoints);
688677
for (size_t i = 0; i < nEntryPoints; i++) {
689678
RedisModule_ReplyWithCString(ctx, sto->entryPoints[i]);
690679
}
680+
if (source || !meta) {
681+
RedisModule_ReplyWithCString(ctx, "source");
682+
RedisModule_ReplyWithCString(ctx, sto->scriptdef);
683+
}
691684
return REDISMODULE_OK;
692685
}
693686

tests/flow/tests_commands.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def test_modelstore_errors(env):
5454
'AI.MODELSTORE', 'm{1}', 'TORCH', DEVICE, 'BATCHSIZE', 2, 'BLOB')
5555

5656

57-
def test_modelget_errors(env):
57+
def test_modelget(env):
5858
if not TEST_TF:
5959
env.debugPrint("Skipping test since TF is not available", force=True)
6060
return
@@ -69,8 +69,17 @@ def test_modelget_errors(env):
6969

7070
# ERR model key is empty
7171
con.execute_command('DEL', 'DONT_EXIST{1}')
72-
check_error_message(env, con, "model key is empty",
73-
'AI.MODELGET', 'DONT_EXIST{1}')
72+
check_error_message(env, con, "model key is empty", 'AI.MODELGET', 'DONT_EXIST{1}')
73+
74+
# The default behaviour on success is return META+BLOB
75+
model_pb = load_file_content('graph.pb')
76+
con.execute_command('AI.MODELSTORE', 'm{1}', 'TF', DEVICE, 'INPUTS', 2, 'a', 'b', 'OUTPUTS', 1, 'mul',
77+
'BLOB', model_pb)
78+
_, backend, _, device, _, tag, _, batchsize, _, minbatchsize, _, inputs, _, outputs, _, minbatchtimeout, _, blob = \
79+
con.execute_command('AI.MODELGET', 'm{1}')
80+
env.assertEqual([backend, device, tag, batchsize, minbatchsize, minbatchtimeout, inputs, outputs],
81+
[b"TF", bytes(DEVICE, "utf8"), b"", 0, 0, 0, [b"a", b"b"], [b"mul"]])
82+
env.assertEqual(blob, model_pb)
7483

7584

7685
def test_modelexecute_errors(env):

tests/flow/tests_deprecated_commands.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ def test_pytorch_scriptset(env):
239239

240240
ret = con.execute_command('AI.SCRIPTSET', 'ket{1}', DEVICE, 'TAG', 'asdf', 'SOURCE', script)
241241
env.assertEqual(ret, b'OK')
242+
_, device, _, tag, _, entry_points, _, source = con.execute_command('AI.SCRIPTGET', 'ket{1}')
243+
env.assertEqual([device, tag, entry_points, source], [bytes(DEVICE, "utf8"), b"asdf", [], script])
242244

243245
ensureSlaveSynced(con, env)
244246

0 commit comments

Comments
 (0)