Skip to content

Commit 87df138

Browse files
committed
[WIP] Enable AI.SCRIPTRUN on AI.DAGRUN* (#383)
* [add] decoupled scriptrun command parsing from runtime scriptrun. * [add] split positive/negative tests on pytorch scriptrun. * [add] refactor AI.DAGRUN_RO and AI.DAGRUN to use the same code base (with read/write modes) * [add] added positive and negative tests for dagrun with scriptrun * [add] updated documentation to reflect scriptrun support on dagrun * [add] added example enqueuing multiple SCRIPTRUN and MODELRUN commands within a DAG
1 parent 94d79e5 commit 87df138

18 files changed

+717
-346
lines changed

docs/commands.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@ _Arguments_
553553
* `AI.TENSORSET`
554554
* `AI.TENSORGET`
555555
* `AI.MODELRUN`
556+
* `AI.SCRIPTRUN`
556557

557558
_Return_
558559

@@ -576,6 +577,30 @@ redis> AI.DAGRUN PERSIST 1 predictions |>
576577
3) "\x00\x00\x80?\x00\x00\x00@\x00\x00@@\x00\x00\x80@"
577578
```
578579

580+
A common pattern is enqueuing multiple SCRIPTRUN and MODELRUN commands within a DAG. The following example uses ResNet-50, to classify images into 1000 object categories. Given that our input tensor contains each color represented as an 8-bit integer, and that neural networks usually work with floating-point tensors as their input, we need to cast the tensor to floating-point and normalize the pixel values — for that we will use the `pre_process_3ch` function.
581+
582+
To optimize the classification process we can use a post process script to return only the category position with the maximum classification - for that we will use `post_process` script. Using the DAG capabilities we've removed the necessity of storing the intermediate tensors in the keyspace. You can even run the entire process without storing the output tensor, as follows:
583+
584+
```
585+
redis> AI.DAGRUN_RO |>
586+
AI.TENSORSET image UINT8 224 224 3 BLOB b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00....' |>
587+
AI.SCRIPTRUN imagenet_script pre_process_3ch INPUTS image OUTPUTS temp_key1 |>
588+
AI.MODELRUN imagenet_model INPUTS temp_key1 OUTPUTS temp_key2 |>
589+
AI.SCRIPTRUN imagenet_script post_process INPUTS temp_key2 OUTPUTS output |>
590+
AI.TENSORGET output VALUES
591+
1) OK
592+
2) OK
593+
3) OK
594+
4) OK
595+
5) 1) 1) (integer) 111
596+
```
597+
598+
As visible in the array reply, the label position with the highest classification score was 111.
599+
600+
By combining DAG with multiple SCRIPTRUN and MODELRUN commands we've substantially reduced the overall required bandwidth and network RX (we're now returning a tensor with 1000 times fewer elements per classification).
601+
602+
603+
579604
!!! warning "Intermediate memory overhead"
580605
The execution of models and scripts within the DAG may generate intermediate tensors that are not allocated by the Redis allocator, but by whatever allocator is used in the backends (which may act on main memory or GPU memory, depending on the device), thus not being limited by `maxmemory` configuration settings of Redis.
581606

src/dag.c

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,38 @@ void *RedisAI_DagRunSession(RedisAI_RunInfo *rinfo) {
106106
}
107107
break;
108108
}
109+
case REDISAI_DAG_CMD_SCRIPTRUN: {
110+
const int parse_result = RedisAI_Parse_ScriptRun_RedisCommand(
111+
NULL, currentOp->argv, currentOp->argc, &(currentOp->sctx),
112+
&(currentOp->outkeys), &(currentOp->sctx->script), 1,
113+
&(rinfo->dagTensorsContext), 0, NULL, currentOp->err);
114+
115+
if (parse_result > 0) {
116+
currentOp->result = REDISMODULE_OK;
117+
const long long start = ustime();
118+
currentOp->result = RAI_ScriptRun(currentOp->sctx, currentOp->err);
119+
currentOp->duration_us = ustime() - start;
120+
const size_t noutputs = RAI_ScriptRunCtxNumOutputs(currentOp->sctx);
121+
for (size_t outputNumber = 0; outputNumber < noutputs;
122+
outputNumber++) {
123+
RAI_Tensor *tensor =
124+
RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber);
125+
if (tensor) {
126+
const char *key_string = RedisModule_StringPtrLen(
127+
currentOp->outkeys[outputNumber], NULL);
128+
const char *dictKey = RedisModule_Strdup(key_string);
129+
AI_dictReplace(rinfo->dagTensorsContext, (void*)dictKey, tensor);
130+
} else {
131+
RAI_SetError(currentOp->err, RAI_EMODELRUN,
132+
"ERR output tensor on DAG's SCRIPTRUN was null");
133+
currentOp->result = REDISMODULE_ERR;
134+
}
135+
}
136+
} else {
137+
currentOp->result = REDISMODULE_ERR;
138+
}
139+
break;
140+
}
109141
default: {
110142
/* unsupported DAG's command */
111143
RAI_SetError(currentOp->err, RAI_EDAGRUN,
@@ -173,6 +205,22 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv,
173205
break;
174206
}
175207

208+
case REDISAI_DAG_CMD_SCRIPTRUN: {
209+
rinfo->dagReplyLength++;
210+
struct RedisAI_RunStats *rstats = NULL;
211+
const char *runkey =
212+
RedisModule_StringPtrLen(currentOp->runkey, NULL);
213+
RAI_GetRunStats(runkey,&rstats);
214+
if (currentOp->result == REDISMODULE_ERR) {
215+
RAI_SafeAddDataPoint(rstats,0,1,1,0);
216+
RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline);
217+
} else {
218+
RAI_SafeAddDataPoint(rstats,currentOp->duration_us,1,0,0);
219+
RedisModule_ReplyWithSimpleString(ctx, "OK");
220+
}
221+
break;
222+
}
223+
176224
default:
177225
/* no-op */
178226
break;
@@ -334,3 +382,161 @@ int RAI_parseDAGPersistArgs(RedisModuleCtx *ctx, RedisModuleString **argv,
334382
}
335383
return argpos;
336384
}
385+
386+
int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv,
387+
int argc, int dagMode) {
388+
if (argc < 4) return RedisModule_WrongArity(ctx);
389+
390+
RedisAI_RunInfo *rinfo = NULL;
391+
if (RAI_InitRunInfo(&rinfo) == REDISMODULE_ERR) {
392+
return RedisModule_ReplyWithError(
393+
ctx,
394+
"ERR Unable to allocate the memory and initialise the RedisAI_RunInfo "
395+
"structure");
396+
}
397+
rinfo->use_local_context = 1;
398+
RAI_DagOp *currentDagOp = NULL;
399+
RAI_InitDagOp(&currentDagOp);
400+
array_append(rinfo->dagOps, currentDagOp);
401+
402+
int persistFlag = 0;
403+
int loadFlag = 0;
404+
int chainingOpCount = 0;
405+
const char *deviceStr = NULL;
406+
407+
for (size_t argpos = 1; argpos <= argc - 1; argpos++) {
408+
const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
409+
if (!strcasecmp(arg_string, "LOAD")) {
410+
loadFlag = 1;
411+
const int parse_result = RAI_parseDAGLoadArgs(
412+
ctx, &argv[argpos], argc - argpos, &(rinfo->dagTensorsLoadedContext),
413+
&(rinfo->dagTensorsContext), "|>");
414+
if (parse_result > 0) {
415+
argpos += parse_result - 1;
416+
} else {
417+
RAI_FreeRunInfo(ctx, rinfo);
418+
return REDISMODULE_ERR;
419+
}
420+
} else if (!strcasecmp(arg_string, "PERSIST")) {
421+
if (dagMode == REDISAI_DAG_READONLY_MODE) {
422+
RAI_FreeRunInfo(ctx, rinfo);
423+
return RedisModule_ReplyWithError(
424+
ctx, "ERR PERSIST cannot be specified in a read-only DAG");
425+
}
426+
persistFlag = 1;
427+
const int parse_result =
428+
RAI_parseDAGPersistArgs(ctx, &argv[argpos], argc - argpos,
429+
&(rinfo->dagTensorsPersistentContext), "|>");
430+
if (parse_result > 0) {
431+
argpos += parse_result - 1;
432+
} else {
433+
RAI_FreeRunInfo(ctx, rinfo);
434+
return REDISMODULE_ERR;
435+
}
436+
} else if (!strcasecmp(arg_string, "|>")) {
437+
// on the first pipe operator, if LOAD or PERSIST were used, we've already
438+
// allocated memory
439+
if (!((persistFlag == 1 || loadFlag == 1) && chainingOpCount == 0)) {
440+
rinfo->dagNumberCommands++;
441+
RAI_DagOp *currentDagOp = NULL;
442+
RAI_InitDagOp(&currentDagOp);
443+
array_append(rinfo->dagOps, currentDagOp);
444+
}
445+
chainingOpCount++;
446+
} else {
447+
if (!strcasecmp(arg_string, "AI.TENSORGET")) {
448+
rinfo->dagOps[rinfo->dagNumberCommands]->commandType =
449+
REDISAI_DAG_CMD_TENSORGET;
450+
}
451+
if (!strcasecmp(arg_string, "AI.TENSORSET")) {
452+
rinfo->dagOps[rinfo->dagNumberCommands]->commandType =
453+
REDISAI_DAG_CMD_TENSORSET;
454+
}
455+
if (!strcasecmp(arg_string, "AI.MODELRUN")) {
456+
if (argc - 2 < argpos) {
457+
return RedisModule_WrongArity(ctx);
458+
}
459+
rinfo->dagOps[rinfo->dagNumberCommands]->commandType =
460+
REDISAI_DAG_CMD_MODELRUN;
461+
RAI_Model *mto;
462+
RedisModuleKey *modelKey;
463+
const int status = RAI_GetModelFromKeyspace(
464+
ctx, argv[argpos + 1], &modelKey, &mto, REDISMODULE_READ);
465+
if (status == REDISMODULE_ERR) {
466+
RAI_FreeRunInfo(ctx, rinfo);
467+
return REDISMODULE_ERR;
468+
}
469+
if (deviceStr == NULL) {
470+
deviceStr = mto->devicestr;
471+
} else {
472+
// If the device strings are not equivalent, reply with error ( for
473+
// now )
474+
if (strcasecmp(mto->devicestr, deviceStr) != 0) {
475+
RAI_FreeRunInfo(ctx, rinfo);
476+
return RedisModule_ReplyWithError(
477+
ctx, "ERR multi-device DAGs not supported yet");
478+
}
479+
}
480+
rinfo->dagOps[rinfo->dagNumberCommands]->runkey = argv[argpos + 1];
481+
rinfo->dagOps[rinfo->dagNumberCommands]->mctx =
482+
RAI_ModelRunCtxCreate(mto);
483+
}
484+
if (!strcasecmp(arg_string, "AI.SCRIPTRUN")) {
485+
if (argc - 3 < argpos) {
486+
return RedisModule_WrongArity(ctx);
487+
}
488+
rinfo->dagOps[rinfo->dagNumberCommands]->commandType =
489+
REDISAI_DAG_CMD_SCRIPTRUN;
490+
RAI_Script *sto;
491+
RedisModuleKey *scriptKey;
492+
const int status = RAI_GetScriptFromKeyspace(
493+
ctx, argv[argpos + 1], &scriptKey, &sto, REDISMODULE_READ);
494+
if (status == REDISMODULE_ERR) {
495+
RAI_FreeRunInfo(ctx, rinfo);
496+
return REDISMODULE_ERR;
497+
}
498+
if (deviceStr == NULL) {
499+
deviceStr = sto->devicestr;
500+
} else {
501+
// If the device strings are not equivalent, reply with error ( for
502+
// now )
503+
if (strcasecmp(sto->devicestr, deviceStr) != 0) {
504+
RAI_FreeRunInfo(ctx, rinfo);
505+
return RedisModule_ReplyWithError(
506+
ctx, "ERR multi-device DAGs not supported yet");
507+
}
508+
}
509+
const char *functionName =
510+
RedisModule_StringPtrLen(argv[argpos + 2], NULL);
511+
rinfo->dagOps[rinfo->dagNumberCommands]->runkey = argv[argpos + 1];
512+
rinfo->dagOps[rinfo->dagNumberCommands]->sctx =
513+
RAI_ScriptRunCtxCreate(sto, functionName);
514+
}
515+
RedisModule_RetainString(NULL, argv[argpos]);
516+
array_append(rinfo->dagOps[rinfo->dagNumberCommands]->argv, argv[argpos]);
517+
rinfo->dagOps[rinfo->dagNumberCommands]->argc++;
518+
}
519+
}
520+
521+
RunQueueInfo *run_queue_info = NULL;
522+
// If there was no MODELRUN or SCRIPTRUN on the DAG, we default all ops to CPU
523+
if (deviceStr == NULL) {
524+
deviceStr = "CPU";
525+
}
526+
// If the queue does not exist, initialize it
527+
if (ensureRunQueue(deviceStr, &run_queue_info) == REDISMODULE_ERR) {
528+
RAI_FreeRunInfo(ctx, rinfo);
529+
return RedisModule_ReplyWithError(ctx,
530+
"ERR Queue not initialized for device");
531+
}
532+
533+
rinfo->client =
534+
RedisModule_BlockClient(ctx, RedisAI_DagRun_Reply, NULL, NULL, 0);
535+
536+
pthread_mutex_lock(&run_queue_info->run_queue_mutex);
537+
queuePush(run_queue_info->run_queue, rinfo);
538+
pthread_cond_signal(&run_queue_info->queue_condition_var);
539+
pthread_mutex_unlock(&run_queue_info->run_queue_mutex);
540+
541+
return REDISMODULE_OK;
542+
}

src/dag.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,18 @@ int RAI_parseDAGPersistArgs(RedisModuleCtx *ctx, RedisModuleString **argv,
7373
int argc, AI_dict **localContextDict,
7474
const char *chaining_operator);
7575

76+
/**
77+
 * DAGRUN and DAGRUN_RO parser, which reads the sequence of
78+
* arguments and decides whether the sequence conforms to the syntax
79+
* specified by the DAG grammar.
80+
*
81+
* @param ctx Context in which Redis modules operate
82+
* @param argv Redis command arguments, as an array of strings
83+
* @param argc Redis command number of arguments
84+
* @param dagMode access mode, for now REDISAI_DAG_READONLY_MODE or REDISAI_DAG_WRITE_MODE
85+
* @return
86+
*/
87+
int RedisAI_DagRunSyntaxParser(RedisModuleCtx *ctx, RedisModuleString **argv,
88+
int argc, int dagMode);
89+
7690
#endif /* SRC_DAG_H_ */

src/model.c

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -559,11 +559,6 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx,
559559
if (!strcasecmp(arg_string, "OUTPUTS") && outputs_flag_count == 0) {
560560
is_input = 1;
561561
outputs_flag_count = 1;
562-
const size_t expected_noutputs = argc - argpos - 1;
563-
// if (expected_noutputs > 0) {
564-
// *outkeys =
565-
// RedisModule_Calloc(expected_noutputs, sizeof(RedisModuleString *));
566-
// }
567562
} else {
568563
RedisModule_RetainString(ctx, argv[argpos]);
569564
if (is_input == 0) {
@@ -613,6 +608,7 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx,
613608
} else {
614609
RedisModule_ReplyWithError(ctx, "ERR Output key not found");
615610
}
611+
return -1;
616612
}
617613
*outkeys=array_append(*outkeys,argv[argpos]);
618614
// (*outkeys)[noutputs] = argv[argpos];

0 commit comments

Comments
 (0)