Skip to content

Commit 6de3a1c

Browse files
author
DvirDukhan
committed
fixed PR comments
1 parent 6b4ea89 commit 6de3a1c

File tree

8 files changed

+113
-104
lines changed

8 files changed

+113
-104
lines changed

src/backends/libtorch_c/torch_c.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ extern "C" void torchRunScript(void *scriptCtx, const char *fnName,
364364
torch::Device device(device_type, ctx->device_id);
365365

366366
torch::jit::Stack stack;
367-
if(inputsCtx->noEntryPoint) {
367+
if(!inputsCtx->hasEntryPoint) {
368368
/* In case of no entry point, this might be due to a usage in a script set by the deprecated API.
369369
* In this case, until SCRIPTSET is EOL we will allow functions, called by SCRIPTRUN or SCRIPTEXECUTE, and those
370370
* functions are not in the endpoint set to be executed in "best effort" manner.

src/backends/libtorch_c/torch_c.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ typedef struct TorchFunctionInputCtx {
1515
size_t argsCount;
1616
RedisModuleString **keys;
1717
size_t keysCount;
18-
bool noEntryPoint; // TODO: remove this when SCRIPTRUN is EOL. Indication that the script was
19-
// stored with SCRIPTSET and and SCRIPTSTORE, such that it has no entry
20-
// point, so execution is best effort.
18+
bool hasEntryPoint; // TODO: remove this when SCRIPTRUN is EOL. Indication that the script was
19+
// stored with SCRIPTSET and not SCRIPTSTORE, such that it has no entry
20+
// point, so execution is best effort.
2121
} TorchFunctionInputCtx;
2222

2323
/**
@@ -157,7 +157,7 @@ const char *torchScript_FunctionName(void *scriptCtx, size_t fn_index);
157157
size_t torchScript_FunctionArgumentCount(void *scriptCtx, size_t fn_index);
158158

159159
/**
160-
* @brief Return the number of arguments in the fuction numbered fn_index in the script.
160+
* @brief Return the number of arguments of a given fuction in the script.
161161
*
162162
* @param scriptCtx Script context.
163163
* @param functionName Function name.
@@ -166,7 +166,7 @@ size_t torchScript_FunctionArgumentCount(void *scriptCtx, size_t fn_index);
166166
size_t torchScript_FunctionArgumentCountByFunctionName(void *scriptCtx, const char *functionName);
167167

168168
/**
169-
* @brief Rerturns the type of the argument at arg_index of function numbered fn_index in the
169+
* @brief Returns the type of the argument at arg_index of function numbered fn_index in the
170170
* script.
171171
*
172172
* @param scriptCtx Script context.
@@ -178,7 +178,7 @@ TorchScriptFunctionArgumentType torchScript_FunctionArgumentType(void *scriptCtx
178178
size_t arg_index);
179179

180180
/**
181-
* @brief Rerturns the type of the argument at arg_index of function numbered fn_index in the
181+
* @brief Returns the type of the argument at arg_index of a given function in the
182182
* script.
183183
*
184184
* @param scriptCtx Script context.

src/backends/torch.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,10 +434,10 @@ int RAI_ScriptRunTorch(RAI_Script *script, const char *function, RAI_ExecutionCt
434434
RAI_ScriptRunCtx *sctx = (RAI_ScriptRunCtx *)ectx;
435435

436436
// TODO: remove when SCRIPTRUN is EOL.
437-
bool noEntryPoint = true;
437+
bool hasEntryPoint = false;
438438
for (size_t i = 0; i < array_len(script->entryPoints); i++) {
439439
if (strcmp(function, script->entryPoints[i]) == 0) {
440-
noEntryPoint = false;
440+
hasEntryPoint = true;
441441
break;
442442
}
443443
}
@@ -449,7 +449,7 @@ int RAI_ScriptRunTorch(RAI_Script *script, const char *function, RAI_ExecutionCt
449449
.argsCount = array_len(sctx->args),
450450
.keys = sctx->keys,
451451
.keysCount = array_len(sctx->keys),
452-
.noEntryPoint = noEntryPoint};
452+
.hasEntryPoint = hasEntryPoint};
453453

454454
torchRunScript(script->script, function, &inputsCtx, outputs, nOutputs, &error_descr);
455455

src/execution/parsing/deprecated.c

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,97 @@ int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
342342
return REDISMODULE_OK;
343343
}
344344

345+
int ScriptSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
346+
if (argc != 5 && argc != 7)
347+
return RedisModule_WrongArity(ctx);
348+
349+
ArgsCursor ac;
350+
ArgsCursor_InitRString(&ac, argv + 1, argc - 1);
351+
352+
RedisModuleString *keystr;
353+
AC_GetRString(&ac, &keystr, 0);
354+
355+
const char *devicestr;
356+
AC_GetString(&ac, &devicestr, NULL, 0);
357+
358+
RedisModuleString *tag = NULL;
359+
if (AC_AdvanceIfMatch(&ac, "TAG")) {
360+
AC_GetRString(&ac, &tag, 0);
361+
}
362+
363+
if (AC_IsAtEnd(&ac)) {
364+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing script SOURCE");
365+
}
366+
367+
size_t scriptlen;
368+
const char *scriptdef = NULL;
369+
370+
if (AC_AdvanceIfMatch(&ac, "SOURCE")) {
371+
AC_GetString(&ac, &scriptdef, &scriptlen, 0);
372+
}
373+
374+
if (scriptdef == NULL) {
375+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing script SOURCE");
376+
}
377+
378+
RAI_Script *script = NULL;
379+
380+
RAI_Error err = {0};
381+
script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);
382+
383+
if (err.code == RAI_EBACKENDNOTLOADED) {
384+
RedisModule_Log(ctx, "warning",
385+
"Backend TORCH not loaded, will try loading default backend");
386+
int ret = RAI_LoadDefaultBackend(ctx, RAI_BACKEND_TORCH);
387+
if (ret == REDISMODULE_ERR) {
388+
RedisModule_Log(ctx, "warning", "Could not load TORCH default backend");
389+
int ret = RedisModule_ReplyWithError(ctx, "ERR Could not load backend");
390+
RAI_ClearError(&err);
391+
return ret;
392+
}
393+
RAI_ClearError(&err);
394+
script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);
395+
}
396+
397+
if (err.code != RAI_OK) {
398+
#ifdef RAI_PRINT_BACKEND_ERRORS
399+
printf("ERR: %s\n", err.detail);
400+
#endif
401+
int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline);
402+
RAI_ClearError(&err);
403+
return ret;
404+
}
405+
406+
if (!RunQueue_IsExists(devicestr)) {
407+
RunQueueInfo *run_queue_info = RunQueue_Create(devicestr);
408+
if (run_queue_info == NULL) {
409+
RAI_ScriptFree(script, &err);
410+
RedisModule_ReplyWithError(ctx, "ERR Could not initialize queue on requested device");
411+
}
412+
}
413+
414+
RedisModuleKey *key = RedisModule_OpenKey(ctx, keystr, REDISMODULE_READ | REDISMODULE_WRITE);
415+
int type = RedisModule_KeyType(key);
416+
if (type != REDISMODULE_KEYTYPE_EMPTY &&
417+
!(type == REDISMODULE_KEYTYPE_MODULE &&
418+
RedisModule_ModuleTypeGetType(key) == RAI_ScriptRedisType())) {
419+
RedisModule_CloseKey(key);
420+
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
421+
}
422+
423+
RedisModule_ModuleTypeSetValue(key, RAI_ScriptRedisType(), script);
424+
425+
script->infokey = RAI_AddStatsEntry(ctx, keystr, RAI_SCRIPT, RAI_BACKEND_TORCH, devicestr, tag);
426+
427+
RedisModule_CloseKey(key);
428+
429+
RedisModule_ReplyWithSimpleString(ctx, "OK");
430+
431+
RedisModule_ReplicateVerbatim(ctx);
432+
433+
return REDISMODULE_OK;
434+
}
435+
345436
static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
346437
RAI_Error *error, RedisModuleString ***inkeys,
347438
RedisModuleString ***outkeys, long long *timeout) {

src/execution/parsing/deprecated.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisMod
1818

1919
int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc);
2020

21+
int ScriptSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc);
22+
2123
/**
2224
* @brief Parse the arguments of the given ops in the DAGRUN command and build every op accordingly.
2325
* @param rinfo The DAG run info that will be populated with the ops if they are valid.

src/redisai.c

Lines changed: 5 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -718,101 +718,14 @@ int RedisAI_ScriptDel_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
718718
* AI.SCRIPTSET script_key device [TAG tag] SOURCE script_source
719719
*/
720720
int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
721-
if (argc != 5 && argc != 7)
722-
return RedisModule_WrongArity(ctx);
723-
724-
ArgsCursor ac;
725-
ArgsCursor_InitRString(&ac, argv + 1, argc - 1);
726-
727-
RedisModuleString *keystr;
728-
AC_GetRString(&ac, &keystr, 0);
729-
730-
const char *devicestr;
731-
AC_GetString(&ac, &devicestr, NULL, 0);
732-
733-
RedisModuleString *tag = NULL;
734-
if (AC_AdvanceIfMatch(&ac, "TAG")) {
735-
AC_GetRString(&ac, &tag, 0);
736-
}
737-
738-
if (AC_IsAtEnd(&ac)) {
739-
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing script SOURCE");
740-
}
741-
742-
size_t scriptlen;
743-
const char *scriptdef = NULL;
744-
745-
if (AC_AdvanceIfMatch(&ac, "SOURCE")) {
746-
AC_GetString(&ac, &scriptdef, &scriptlen, 0);
747-
}
748-
749-
if (scriptdef == NULL) {
750-
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing script SOURCE");
751-
}
752-
753-
RAI_Script *script = NULL;
754-
755-
RAI_Error err = {0};
756-
script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);
757-
758-
if (err.code == RAI_EBACKENDNOTLOADED) {
759-
RedisModule_Log(ctx, "warning",
760-
"Backend TORCH not loaded, will try loading default backend");
761-
int ret = RAI_LoadDefaultBackend(ctx, RAI_BACKEND_TORCH);
762-
if (ret == REDISMODULE_ERR) {
763-
RedisModule_Log(ctx, "warning", "Could not load TORCH default backend");
764-
int ret = RedisModule_ReplyWithError(ctx, "ERR Could not load backend");
765-
RAI_ClearError(&err);
766-
return ret;
767-
}
768-
RAI_ClearError(&err);
769-
script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);
770-
}
771-
772-
if (err.code != RAI_OK) {
773-
#ifdef RAI_PRINT_BACKEND_ERRORS
774-
printf("ERR: %s\n", err.detail);
775-
#endif
776-
int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline);
777-
RAI_ClearError(&err);
778-
return ret;
779-
}
780-
781-
if (!RunQueue_IsExists(devicestr)) {
782-
RunQueueInfo *run_queue_info = RunQueue_Create(devicestr);
783-
if (run_queue_info == NULL) {
784-
RAI_ScriptFree(script, &err);
785-
RedisModule_ReplyWithError(ctx, "ERR Could not initialize queue on requested device");
786-
}
787-
}
788-
789-
RedisModuleKey *key = RedisModule_OpenKey(ctx, keystr, REDISMODULE_READ | REDISMODULE_WRITE);
790-
int type = RedisModule_KeyType(key);
791-
if (type != REDISMODULE_KEYTYPE_EMPTY &&
792-
!(type == REDISMODULE_KEYTYPE_MODULE &&
793-
RedisModule_ModuleTypeGetType(key) == RAI_ScriptRedisType())) {
794-
RedisModule_CloseKey(key);
795-
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
796-
}
797-
798-
RedisModule_ModuleTypeSetValue(key, RAI_ScriptRedisType(), script);
799-
800-
script->infokey = RAI_AddStatsEntry(ctx, keystr, RAI_SCRIPT, RAI_BACKEND_TORCH, devicestr, tag);
801-
802-
RedisModule_CloseKey(key);
803-
804-
RedisModule_ReplyWithSimpleString(ctx, "OK");
805-
806-
RedisModule_ReplicateVerbatim(ctx);
807-
808-
return REDISMODULE_OK;
721+
RedisModule_Log(ctx, "warning",
722+
"AI.SCRIPTSET command is deprecated and will"
723+
" not be available in future version, you can use AI.SCRIPTSTORE instead");
724+
return ScriptSetCommand(ctx, argv, argc);
809725
}
810726

811-
/*
812-
* Todo: this is temporary until we implement the new command, for testing broadcast in DMC
813-
*/
814727
int RedisAI_ScriptStore_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
815-
// AI.SCRIPTSET <key> <device> ENTRY_POINTS 1 ep1 SOURCE blob
728+
// AI.SCRIPTSTORE <key> <device> ENTRY_POINTS 1 ep1 SOURCE blob
816729
if (argc < 8)
817730
return RedisModule_WrongArity(ctx);
818731

src/serialization/AOF/rai_aof_rewrite.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ void RAI_AOFRewriteModel(RedisModuleIO *aof, RedisModuleString *key, void *value
110110

111111
void RAI_AOFRewriteScript(RedisModuleIO *aof, RedisModuleString *key, void *value) {
112112
RAI_Script *script = (RAI_Script *)value;
113-
RedisModuleString **args = array_new(RedisModuleString *, 4);
113+
RedisModuleString **args = array_new(RedisModuleString *, 9);
114114
args = array_append(args, RedisModule_CreateStringFromString(NULL, key));
115115
args = array_append(
116116
args, RedisModule_CreateString(NULL, script->devicestr, strlen(script->devicestr)));
@@ -131,7 +131,7 @@ void RAI_AOFRewriteScript(RedisModuleIO *aof, RedisModuleString *key, void *valu
131131
args = array_append(
132132
args, RedisModule_CreateString(NULL, script->scriptdef, strlen(script->scriptdef)));
133133

134-
RedisModule_EmitAOF(aof, "AI.SCRIPTSTORE", "v", args);
134+
RedisModule_EmitAOF(aof, "AI.SCRIPTSTORE", "v", args, array_len(args));
135135
for (size_t i = 0; i < array_len(args); i++) {
136136
RedisModule_FreeString(NULL, args[i]);
137137
}

tests/flow/tests_pytorch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,9 @@ def test_pytorch_scriptexecute_errors(env):
457457

458458
check_error_message(env, con, "KEYS scope must be provided first for AI.SCRIPTEXECUTE command", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'INPUTS', 'OUTPUTS')
459459

460+
check_error_message(env, con, "Invalid value for TIMEOUT",'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1, '{1}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}', 'TIMEOUT', 'TIMEOUT')
461+
462+
460463
if env.isCluster():
461464
# cross shard
462465
check_error_message(env, con, "CROSSSLOT Keys in request don't hash to the same slot", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1 , '{2}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}')

0 commit comments

Comments
 (0)