Skip to content

Commit efa4219

Browse files
author
DvirDukhan
committed
added variadic to llapi
1 parent 24f02b7 commit efa4219

File tree

6 files changed

+63
-14
lines changed

6 files changed

+63
-14
lines changed

src/model.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,6 @@ int RedisAI_Parse_ModelRun_RedisCommand(RedisModuleCtx *ctx,
646646
return argpos;
647647
}
648648

649-
RedisModuleType *RAI_getModelRedisType(void) {
649+
RedisModuleType *RAI_ModelRedisType(void) {
650650
return RedisAI_ModelType;
651651
}

src/redisai.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,7 @@ static int RedisAI_RegisterApi(RedisModuleCtx* ctx) {
985985
REGISTER_API(ScriptFree, ctx);
986986
REGISTER_API(ScriptRunCtxCreate, ctx);
987987
REGISTER_API(ScriptRunCtxAddInput, ctx);
988+
REGISTER_API(ScriptRunCtxAddInputList, ctx);
988989
REGISTER_API(ScriptRunCtxAddOutput, ctx);
989990
REGISTER_API(ScriptRunCtxNumOutputs, ctx);
990991
REGISTER_API(ScriptRunCtxOutputTensor, ctx);

src/redisai.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ int MODULE_API_FUNC(RedisAI_TensorNumDims)(RAI_Tensor* t);
7878
long long MODULE_API_FUNC(RedisAI_TensorDim)(RAI_Tensor* t, int dim);
7979
size_t MODULE_API_FUNC(RedisAI_TensorByteSize)(RAI_Tensor* t);
8080
char* MODULE_API_FUNC(RedisAI_TensorData)(RAI_Tensor* t);
81-
RedisModuleType MODULE_API_FUNC(RedisAI_TensorRedisType)(void);
81+
RedisModuleType* MODULE_API_FUNC(RedisAI_TensorRedisType)(void);
8282

8383
RAI_Model* MODULE_API_FUNC(RedisAI_ModelCreate)(int backend, char* devicestr, char* tag, RAI_ModelOpts opts,
8484
size_t ninputs, const char **inputs,
@@ -94,19 +94,20 @@ void MODULE_API_FUNC(RedisAI_ModelRunCtxFree)(RAI_ModelRunCtx* mctx);
9494
int MODULE_API_FUNC(RedisAI_ModelRun)(RAI_ModelRunCtx** mctx, long long n, RAI_Error* err);
9595
RAI_Model* MODULE_API_FUNC(RedisAI_ModelGetShallowCopy)(RAI_Model* model);
9696
int MODULE_API_FUNC(RedisAI_ModelSerialize)(RAI_Model *model, char **buffer, size_t *len, RAI_Error *err);
97-
RedisModuleType MODULE_API_FUNC(RedisAI_ModelRedisType)(void);
97+
RedisModuleType* MODULE_API_FUNC(RedisAI_ModelRedisType)(void);
9898

9999
RAI_Script* MODULE_API_FUNC(RedisAI_ScriptCreate)(char* devicestr, char* tag, const char* scriptdef, RAI_Error* err);
100100
void MODULE_API_FUNC(RedisAI_ScriptFree)(RAI_Script* script, RAI_Error* err);
101101
RAI_ScriptRunCtx* MODULE_API_FUNC(RedisAI_ScriptRunCtxCreate)(RAI_Script* script, const char *fnname);
102-
int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInput)(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor);
102+
int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInput)(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor, RAI_Error* err);
103+
int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddInputList)(RAI_ScriptRunCtx* sctx, RAI_Tensor** inputTensors, size_t len, RAI_Error* err);
103104
int MODULE_API_FUNC(RedisAI_ScriptRunCtxAddOutput)(RAI_ScriptRunCtx* sctx);
104105
size_t MODULE_API_FUNC(RedisAI_ScriptRunCtxNumOutputs)(RAI_ScriptRunCtx* sctx);
105106
RAI_Tensor* MODULE_API_FUNC(RedisAI_ScriptRunCtxOutputTensor)(RAI_ScriptRunCtx* sctx, size_t index);
106107
void MODULE_API_FUNC(RedisAI_ScriptRunCtxFree)(RAI_ScriptRunCtx* sctx);
107108
int MODULE_API_FUNC(RedisAI_ScriptRun)(RAI_ScriptRunCtx* sctx, RAI_Error* err);
108109
RAI_Script* MODULE_API_FUNC(RedisAI_ScriptGetShallowCopy)(RAI_Script* script);
109-
RedisModuleType MODULE_API_FUNC(RedisAI_ScriptRedisType)(void);
110+
RedisModuleType* MODULE_API_FUNC(RedisAI_ScriptRedisType)(void);
110111

111112
int MODULE_API_FUNC(RedisAI_GetLLAPIVersion)();
112113

@@ -167,6 +168,7 @@ static int RedisAI_Initialize(RedisModuleCtx* ctx){
167168
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptFree);
168169
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxCreate);
169170
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddInput);
171+
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddInputList);
170172
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxAddOutput);
171173
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxNumOutputs);
172174
REDISAI_MODULE_INIT_FUNCTION(ctx, ScriptRunCtxOutputTensor);

src/script.c

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,27 @@ static int Script_RunCtxAddParam(RAI_ScriptRunCtx* sctx,
164164
return 1;
165165
}
166166

167-
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor) {
168-
return Script_RunCtxAddParam(sctx, sctx->inputs, inputTensor);
167+
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor, RAI_Error* err) {
168+
if(sctx->variadic != -1) {
169+
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Already encountered a variable size list of tensors");
170+
return 0;
171+
}
172+
return Script_RunCtxAddParam(sctx, sctx->inputs, inputTensor);
173+
}
174+
175+
int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx* sctx, RAI_Tensor** inputTensors, size_t len, RAI_Error* err) {
176+
// If this is the first time a list is added, set the variadic, else return an error.
177+
if(sctx->variadic == -1) {
178+
sctx->variadic = array_len(sctx->inputs);
179+
}
180+
else {
181+
RAI_SetError(err, RAI_EBACKENDNOTLOADED, "ERR Already encountered a variable size list of tensors");
182+
return 0;
183+
}
184+
for(size_t i=0; i < len; i++){
185+
Script_RunCtxAddParam(sctx, sctx->inputs, inputTensors[i]);
186+
}
187+
return 1;
169188
}
170189

171190
int RAI_ScriptRunCtxAddOutput(RAI_ScriptRunCtx* sctx) {
@@ -270,7 +289,8 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
270289
int is_input = 0;
271290
int outputs_flag_count = 0;
272291
size_t argpos = 4;
273-
292+
// Keep variadic local variable as the calls for RAI_ScriptRunCtxAddInput check if (*sctx)->variadic already assigned.
293+
size_t variadic = (*sctx)->variadic;
274294
for (; argpos <= argc - 1; argpos++) {
275295
const char *arg_string = RedisModule_StringPtrLen(argv[argpos], NULL);
276296
if(!arg_string){
@@ -287,7 +307,11 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
287307
outputs_flag_count = 1;
288308
} else {
289309
if (!strcasecmp(arg_string, "$")) {
290-
(*sctx)->variadic = argpos - 4;
310+
if(variadic > -1) {
311+
RedisAI_ReplyOrSetError(ctx,error,RAI_ESCRIPTRUN, "ERR Already encountered a variable size list of tensors");
312+
return -1;
313+
}
314+
variadic = argpos - 4;
291315
continue;
292316
}
293317
RedisModule_RetainString(ctx, argv[argpos]);
@@ -309,10 +333,7 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
309333
return -1;
310334
}
311335
}
312-
if (!RAI_ScriptRunCtxAddInput(*sctx, inputTensor)) {
313-
RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Input key not found");
314-
return -1;
315-
}
336+
if (!RAI_ScriptRunCtxAddInput(*sctx, inputTensor, error)) return -1;
316337
} else {
317338
if (!RAI_ScriptRunCtxAddOutput(*sctx)) {
318339
RedisAI_ReplyOrSetError(ctx, error, RAI_ESCRIPTRUN, "ERR Output key not found");
@@ -322,6 +343,8 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
322343
}
323344
}
324345
}
346+
// In case variadic position found, set it in the context.
347+
(*sctx)->variadic = variadic;
325348
return argpos;
326349
}
327350

src/script.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,25 @@ RAI_ScriptRunCtx* RAI_ScriptRunCtxCreate(RAI_Script* script,
6767
*
6868
* @param sctx input RAI_ScriptRunCtx to add the input tensor
6969
* @param inputTensor input tensor structure
70+
* @param err error data structure to store error message in the case of
71+
* failures
72+
* @return returns 1 on success ( always returns success )
73+
*/
74+
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor, RAI_Error* err);
75+
76+
/**
77+
* For each Allocates a RAI_ScriptCtxParam data structure, and enforces a shallow copy of
78+
* the provided input tensor, adding it to the input tensors array of the
79+
* RAI_ScriptRunCtx.
80+
*
81+
* @param sctx input RAI_ScriptRunCtx to add the input tensor
82+
* @param inputTensors input tensors array
83+
* @param len input tensors array len
84+
* @param err error data structure to store error message in the case of
85+
* failures
7086
* @return returns 1 on success ( always returns success )
7187
*/
72-
int RAI_ScriptRunCtxAddInput(RAI_ScriptRunCtx* sctx, RAI_Tensor* inputTensor);
88+
int RAI_ScriptRunCtxAddInputList(RAI_ScriptRunCtx* sctx, RAI_Tensor** inputTensors, size_t len, RAI_Error* err);
7389

7490
/**
7591
* Allocates a RAI_ScriptCtxParam data structure, and sets the tensor reference

test/tests_pytorch.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,13 @@ def test_pytorch_scriptrun_errors(env):
641641
except Exception as e:
642642
exception = e
643643
env.assertEqual(type(exception), redis.exceptions.ResponseError)
644+
645+
# "ERR Already encountered a variable size list of tensors"
646+
try:
647+
con.execute_command('AI.SCRIPTRUN', 'ket', 'bar_variadic', 'INPUTS', '$', 'a', '$', 'b' 'OUTPUTS')
648+
except Exception as e:
649+
exception = e
650+
env.assertEqual(type(exception), redis.exceptions.ResponseError)
644651

645652

646653
def test_pytorch_scriptinfo(env):

0 commit comments

Comments
 (0)