Skip to content

Commit 4e67931

Browse files
author
DvirDukhan
committed
fixed PR comments
1 parent c12c692 commit 4e67931

File tree

8 files changed

+113
-104
lines changed

8 files changed

+113
-104
lines changed

src/backends/libtorch_c/torch_c.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ extern "C" void torchRunScript(void *scriptCtx, const char *fnName,
338338
torch::Device device(device_type, ctx->device_id);
339339

340340
torch::jit::Stack stack;
341-
if(inputsCtx->noEntryPoint) {
341+
if(!inputsCtx->hasEntryPoint) {
342342
/* In case of no entry point, this might be due to a usage in a script set by the deprecated API.
343343
* In this case, until SCRIPTSET is EOL we will allow functions, called by SCRIPTRUN or SCRIPTEXECUTE, and those
344344
* functions are not in the endpoint set to be executed in "best effort" manner.

src/backends/libtorch_c/torch_c.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ typedef struct TorchFunctionInputCtx {
1515
size_t argsCount;
1616
RedisModuleString **keys;
1717
size_t keysCount;
18-
bool noEntryPoint; // TODO: remove this when SCRIPTRUN is EOL. Indication that the script was
19-
// stored with SCRIPTSET and and SCRIPTSTORE, such that it has no entry
20-
// point, so execution is best effort.
18+
bool hasEntryPoint; // TODO: remove this when SCRIPTRUN is EOL. Indication that the script was
19+
// stored with SCRIPTSET and not SCRIPTSTORE, such that it has no entry
20+
// point, so execution is best effort.
2121
} TorchFunctionInputCtx;
2222

2323
/**
@@ -157,7 +157,7 @@ const char *torchScript_FunctionName(void *scriptCtx, size_t fn_index);
157157
size_t torchScript_FunctionArgumentCount(void *scriptCtx, size_t fn_index);
158158

159159
/**
160-
* @brief Return the number of arguments in the fuction numbered fn_index in the script.
160+
* @brief Return the number of arguments of a given fuction in the script.
161161
*
162162
* @param scriptCtx Script context.
163163
* @param functionName Function name.
@@ -166,7 +166,7 @@ size_t torchScript_FunctionArgumentCount(void *scriptCtx, size_t fn_index);
166166
size_t torchScript_FunctionArgumentCountByFunctionName(void *scriptCtx, const char *functionName);
167167

168168
/**
169-
* @brief Rerturns the type of the argument at arg_index of function numbered fn_index in the
169+
* @brief Returns the type of the argument at arg_index of function numbered fn_index in the
170170
* script.
171171
*
172172
* @param scriptCtx Script context.
@@ -178,7 +178,7 @@ TorchScriptFunctionArgumentType torchScript_FunctionArgumentType(void *scriptCtx
178178
size_t arg_index);
179179

180180
/**
181-
* @brief Rerturns the type of the argument at arg_index of function numbered fn_index in the
181+
* @brief Returns the type of the argument at arg_index of a given function in the
182182
* script.
183183
*
184184
* @param scriptCtx Script context.

src/backends/torch.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -414,10 +414,10 @@ int RAI_ScriptRunTorch(RAI_Script *script, const char *function, RAI_ExecutionCt
414414
RAI_ScriptRunCtx *sctx = (RAI_ScriptRunCtx *)ectx;
415415

416416
// TODO: remove when SCRIPTRUN is EOL.
417-
bool noEntryPoint = true;
417+
bool hasEntryPoint = false;
418418
for (size_t i = 0; i < array_len(script->entryPoints); i++) {
419419
if (strcmp(function, script->entryPoints[i]) == 0) {
420-
noEntryPoint = false;
420+
hasEntryPoint = true;
421421
break;
422422
}
423423
}
@@ -429,7 +429,7 @@ int RAI_ScriptRunTorch(RAI_Script *script, const char *function, RAI_ExecutionCt
429429
.argsCount = array_len(sctx->args),
430430
.keys = sctx->keys,
431431
.keysCount = array_len(sctx->keys),
432-
.noEntryPoint = noEntryPoint};
432+
.hasEntryPoint = hasEntryPoint};
433433

434434
torchRunScript(script->script, function, &inputsCtx, outputs, nOutputs, &error_descr);
435435

src/execution/parsing/deprecated.c

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,97 @@ int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
342342
return REDISMODULE_OK;
343343
}
344344

345+
int ScriptSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
346+
if (argc != 5 && argc != 7)
347+
return RedisModule_WrongArity(ctx);
348+
349+
ArgsCursor ac;
350+
ArgsCursor_InitRString(&ac, argv + 1, argc - 1);
351+
352+
RedisModuleString *keystr;
353+
AC_GetRString(&ac, &keystr, 0);
354+
355+
const char *devicestr;
356+
AC_GetString(&ac, &devicestr, NULL, 0);
357+
358+
RedisModuleString *tag = NULL;
359+
if (AC_AdvanceIfMatch(&ac, "TAG")) {
360+
AC_GetRString(&ac, &tag, 0);
361+
}
362+
363+
if (AC_IsAtEnd(&ac)) {
364+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing script SOURCE");
365+
}
366+
367+
size_t scriptlen;
368+
const char *scriptdef = NULL;
369+
370+
if (AC_AdvanceIfMatch(&ac, "SOURCE")) {
371+
AC_GetString(&ac, &scriptdef, &scriptlen, 0);
372+
}
373+
374+
if (scriptdef == NULL) {
375+
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing script SOURCE");
376+
}
377+
378+
RAI_Script *script = NULL;
379+
380+
RAI_Error err = {0};
381+
script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);
382+
383+
if (err.code == RAI_EBACKENDNOTLOADED) {
384+
RedisModule_Log(ctx, "warning",
385+
"Backend TORCH not loaded, will try loading default backend");
386+
int ret = RAI_LoadDefaultBackend(ctx, RAI_BACKEND_TORCH);
387+
if (ret == REDISMODULE_ERR) {
388+
RedisModule_Log(ctx, "warning", "Could not load TORCH default backend");
389+
int ret = RedisModule_ReplyWithError(ctx, "ERR Could not load backend");
390+
RAI_ClearError(&err);
391+
return ret;
392+
}
393+
RAI_ClearError(&err);
394+
script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);
395+
}
396+
397+
if (err.code != RAI_OK) {
398+
#ifdef RAI_PRINT_BACKEND_ERRORS
399+
printf("ERR: %s\n", err.detail);
400+
#endif
401+
int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline);
402+
RAI_ClearError(&err);
403+
return ret;
404+
}
405+
406+
if (!RunQueue_IsExists(devicestr)) {
407+
RunQueueInfo *run_queue_info = RunQueue_Create(devicestr);
408+
if (run_queue_info == NULL) {
409+
RAI_ScriptFree(script, &err);
410+
RedisModule_ReplyWithError(ctx, "ERR Could not initialize queue on requested device");
411+
}
412+
}
413+
414+
RedisModuleKey *key = RedisModule_OpenKey(ctx, keystr, REDISMODULE_READ | REDISMODULE_WRITE);
415+
int type = RedisModule_KeyType(key);
416+
if (type != REDISMODULE_KEYTYPE_EMPTY &&
417+
!(type == REDISMODULE_KEYTYPE_MODULE &&
418+
RedisModule_ModuleTypeGetType(key) == RAI_ScriptRedisType())) {
419+
RedisModule_CloseKey(key);
420+
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
421+
}
422+
423+
RedisModule_ModuleTypeSetValue(key, RAI_ScriptRedisType(), script);
424+
425+
script->infokey = RAI_AddStatsEntry(ctx, keystr, RAI_SCRIPT, RAI_BACKEND_TORCH, devicestr, tag);
426+
427+
RedisModule_CloseKey(key);
428+
429+
RedisModule_ReplyWithSimpleString(ctx, "OK");
430+
431+
RedisModule_ReplicateVerbatim(ctx);
432+
433+
return REDISMODULE_OK;
434+
}
435+
345436
static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
346437
RAI_Error *error, RedisModuleString ***inkeys,
347438
RedisModuleString ***outkeys, long long *timeout) {

src/execution/parsing/deprecated.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ int ParseScriptRunCommand(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp, RedisMod
1818

1919
int ModelSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc);
2020

21+
int ScriptSetCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc);
22+
2123
/**
2224
* @brief Parse the arguments of the given ops in the DAGRUN command and build every op accordingly.
2325
* @param rinfo The DAG run info that will be populated with the ops if they are valid.

src/redisai.c

Lines changed: 5 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -722,101 +722,14 @@ int RedisAI_ScriptDel_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
722722
* AI.SCRIPTSET script_key device [TAG tag] SOURCE script_source
723723
*/
724724
int RedisAI_ScriptSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
725-
if (argc != 5 && argc != 7)
726-
return RedisModule_WrongArity(ctx);
727-
728-
ArgsCursor ac;
729-
ArgsCursor_InitRString(&ac, argv + 1, argc - 1);
730-
731-
RedisModuleString *keystr;
732-
AC_GetRString(&ac, &keystr, 0);
733-
734-
const char *devicestr;
735-
AC_GetString(&ac, &devicestr, NULL, 0);
736-
737-
RedisModuleString *tag = NULL;
738-
if (AC_AdvanceIfMatch(&ac, "TAG")) {
739-
AC_GetRString(&ac, &tag, 0);
740-
}
741-
742-
if (AC_IsAtEnd(&ac)) {
743-
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing script SOURCE");
744-
}
745-
746-
size_t scriptlen;
747-
const char *scriptdef = NULL;
748-
749-
if (AC_AdvanceIfMatch(&ac, "SOURCE")) {
750-
AC_GetString(&ac, &scriptdef, &scriptlen, 0);
751-
}
752-
753-
if (scriptdef == NULL) {
754-
return RedisModule_ReplyWithError(ctx, "ERR Insufficient arguments, missing script SOURCE");
755-
}
756-
757-
RAI_Script *script = NULL;
758-
759-
RAI_Error err = {0};
760-
script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);
761-
762-
if (err.code == RAI_EBACKENDNOTLOADED) {
763-
RedisModule_Log(ctx, "warning",
764-
"Backend TORCH not loaded, will try loading default backend");
765-
int ret = RAI_LoadDefaultBackend(ctx, RAI_BACKEND_TORCH);
766-
if (ret == REDISMODULE_ERR) {
767-
RedisModule_Log(ctx, "warning", "Could not load TORCH default backend");
768-
int ret = RedisModule_ReplyWithError(ctx, "ERR Could not load backend");
769-
RAI_ClearError(&err);
770-
return ret;
771-
}
772-
RAI_ClearError(&err);
773-
script = RAI_ScriptCreate(devicestr, tag, scriptdef, &err);
774-
}
775-
776-
if (err.code != RAI_OK) {
777-
#ifdef RAI_PRINT_BACKEND_ERRORS
778-
printf("ERR: %s\n", err.detail);
779-
#endif
780-
int ret = RedisModule_ReplyWithError(ctx, err.detail_oneline);
781-
RAI_ClearError(&err);
782-
return ret;
783-
}
784-
785-
if (!RunQueue_IsExists(devicestr)) {
786-
RunQueueInfo *run_queue_info = RunQueue_Create(devicestr);
787-
if (run_queue_info == NULL) {
788-
RAI_ScriptFree(script, &err);
789-
RedisModule_ReplyWithError(ctx, "ERR Could not initialize queue on requested device");
790-
}
791-
}
792-
793-
RedisModuleKey *key = RedisModule_OpenKey(ctx, keystr, REDISMODULE_READ | REDISMODULE_WRITE);
794-
int type = RedisModule_KeyType(key);
795-
if (type != REDISMODULE_KEYTYPE_EMPTY &&
796-
!(type == REDISMODULE_KEYTYPE_MODULE &&
797-
RedisModule_ModuleTypeGetType(key) == RAI_ScriptRedisType())) {
798-
RedisModule_CloseKey(key);
799-
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
800-
}
801-
802-
RedisModule_ModuleTypeSetValue(key, RAI_ScriptRedisType(), script);
803-
804-
script->infokey = RAI_AddStatsEntry(ctx, keystr, RAI_SCRIPT, RAI_BACKEND_TORCH, devicestr, tag);
805-
806-
RedisModule_CloseKey(key);
807-
808-
RedisModule_ReplyWithSimpleString(ctx, "OK");
809-
810-
RedisModule_ReplicateVerbatim(ctx);
811-
812-
return REDISMODULE_OK;
725+
RedisModule_Log(ctx, "warning",
726+
"AI.SCRIPTSET command is deprecated and will"
727+
" not be available in future version, you can use AI.SCRIPTSTORE instead");
728+
return ScriptSetCommand(ctx, argv, argc);
813729
}
814730

815-
/*
816-
* Todo: this is temporary until we implement the new command, for testing broadcast in DMC
817-
*/
818731
int RedisAI_ScriptStore_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
819-
// AI.SCRIPTSET <key> <device> ENTRY_POINTS 1 ep1 SOURCE blob
732+
// AI.SCRIPTSTORE <key> <device> ENTRY_POINTS 1 ep1 SOURCE blob
820733
if (argc < 8)
821734
return RedisModule_WrongArity(ctx);
822735

src/serialization/AOF/rai_aof_rewrite.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ void RAI_AOFRewriteModel(RedisModuleIO *aof, RedisModuleString *key, void *value
110110

111111
void RAI_AOFRewriteScript(RedisModuleIO *aof, RedisModuleString *key, void *value) {
112112
RAI_Script *script = (RAI_Script *)value;
113-
RedisModuleString **args = array_new(RedisModuleString *, 4);
113+
RedisModuleString **args = array_new(RedisModuleString *, 9);
114114
args = array_append(args, RedisModule_CreateStringFromString(NULL, key));
115115
args = array_append(
116116
args, RedisModule_CreateString(NULL, script->devicestr, strlen(script->devicestr)));
@@ -131,7 +131,7 @@ void RAI_AOFRewriteScript(RedisModuleIO *aof, RedisModuleString *key, void *valu
131131
args = array_append(
132132
args, RedisModule_CreateString(NULL, script->scriptdef, strlen(script->scriptdef)));
133133

134-
RedisModule_EmitAOF(aof, "AI.SCRIPTSTORE", "v", args);
134+
RedisModule_EmitAOF(aof, "AI.SCRIPTSTORE", "v", args, array_len(args));
135135
for (size_t i = 0; i < array_len(args); i++) {
136136
RedisModule_FreeString(NULL, args[i]);
137137
}

tests/flow/tests_pytorch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,9 @@ def test_pytorch_scriptexecute_errors(env):
457457

458458
check_error_message(env, con, "KEYS scope must be provided first for AI.SCRIPTEXECUTE command", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'INPUTS', 'OUTPUTS')
459459

460+
check_error_message(env, con, "Invalid value for TIMEOUT",'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1, '{1}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}', 'TIMEOUT', 'TIMEOUT')
461+
462+
460463
if env.isCluster():
461464
# cross shard
462465
check_error_message(env, con, "CROSSSLOT Keys in request don't hash to the same slot", 'AI.SCRIPTEXECUTE', 'ket{1}', 'bar', 'KEYS', 1 , '{2}', 'INPUTS', 2, 'a{1}', 'b{1}', 'OUTPUTS', 1, 'c{1}')

0 commit comments

Comments
 (0)