@@ -156,21 +156,40 @@ RAI_ScriptRunCtx* RAI_ScriptRunCtxCreate(RAI_Script* script,
156
156
}
157
157
158
158
static int Script_RunCtxAddParam (RAI_ScriptRunCtx * sctx ,
159
- RAI_ScriptCtxParam * paramArr ,
159
+ RAI_ScriptCtxParam * * paramArr ,
160
160
RAI_Tensor * tensor ) {
161
161
RAI_ScriptCtxParam param = {
162
162
.tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ,
163
163
};
164
- paramArr = array_append (paramArr , param );
164
+ * paramArr = array_append (* paramArr , param );
165
165
return 1 ;
166
166
}
167
167
168
- int RAI_ScriptRunCtxAddInput (RAI_ScriptRunCtx * sctx , RAI_Tensor * inputTensor ) {
169
- return Script_RunCtxAddParam (sctx , sctx -> inputs , inputTensor );
168
+ int RAI_ScriptRunCtxAddInput (RAI_ScriptRunCtx * sctx , RAI_Tensor * inputTensor , RAI_Error * err ) {
169
+ if (sctx -> variadic != -1 ) {
170
+ RAI_SetError (err , RAI_EBACKENDNOTLOADED , "ERR Already encountered a variable size list of tensors" );
171
+ return 0 ;
172
+ }
173
+ return Script_RunCtxAddParam (sctx , & sctx -> inputs , inputTensor );
174
+ }
175
+
176
+ int RAI_ScriptRunCtxAddInputList (RAI_ScriptRunCtx * sctx , RAI_Tensor * * inputTensors , size_t len , RAI_Error * err ) {
177
+ // If this is the first time a list is added, set the variadic, else return an error.
178
+ if (sctx -> variadic == -1 ) {
179
+ sctx -> variadic = array_len (sctx -> inputs );
180
+ }
181
+ else {
182
+ RAI_SetError (err , RAI_EBACKENDNOTLOADED , "ERR Already encountered a variable size list of tensors" );
183
+ return 0 ;
184
+ }
185
+ for (size_t i = 0 ; i < len ; i ++ ){
186
+ Script_RunCtxAddParam (sctx , & sctx -> inputs , inputTensors [i ]);
187
+ }
188
+ return 1 ;
170
189
}
171
190
172
191
int RAI_ScriptRunCtxAddOutput (RAI_ScriptRunCtx * sctx ) {
173
- return Script_RunCtxAddParam (sctx , sctx -> outputs , NULL );
192
+ return Script_RunCtxAddParam (sctx , & sctx -> outputs , NULL );
174
193
}
175
194
176
195
size_t RAI_ScriptRunCtxNumOutputs (RAI_ScriptRunCtx * sctx ) {
@@ -274,7 +293,8 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
274
293
int is_input = 0 ;
275
294
int outputs_flag_count = 0 ;
276
295
size_t argpos = 4 ;
277
-
296
+ // Keep variadic local variable as the calls for RAI_ScriptRunCtxAddInput check if (*sctx)->variadic already assigned.
297
+ size_t variadic = (* sctx )-> variadic ;
278
298
for (; argpos <= argc - 1 ; argpos ++ ) {
279
299
const char * arg_string = RedisModule_StringPtrLen (argv [argpos ], NULL );
280
300
if (!arg_string ){
@@ -291,7 +311,11 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
291
311
outputs_flag_count = 1 ;
292
312
} else {
293
313
if (!strcasecmp (arg_string , "$" )) {
294
- (* sctx )-> variadic = argpos - 4 ;
314
+ if (variadic > -1 ) {
315
+ RedisAI_ReplyOrSetError (ctx ,error ,RAI_ESCRIPTRUN , "ERR Already encountered a variable size list of tensors" );
316
+ return -1 ;
317
+ }
318
+ variadic = argpos - 4 ;
295
319
continue ;
296
320
}
297
321
RedisModule_RetainString (ctx , argv [argpos ]);
@@ -313,10 +337,7 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
313
337
return -1 ;
314
338
}
315
339
}
316
- if (!RAI_ScriptRunCtxAddInput (* sctx , inputTensor )) {
317
- RedisAI_ReplyOrSetError (ctx , error , RAI_ESCRIPTRUN , "ERR Input key not found" );
318
- return -1 ;
319
- }
340
+ if (!RAI_ScriptRunCtxAddInput (* sctx , inputTensor , error )) return -1 ;
320
341
} else {
321
342
if (!RAI_ScriptRunCtxAddOutput (* sctx )) {
322
343
RedisAI_ReplyOrSetError (ctx , error , RAI_ESCRIPTRUN , "ERR Output key not found" );
@@ -326,6 +347,8 @@ int RedisAI_Parse_ScriptRun_RedisCommand(RedisModuleCtx *ctx,
326
347
}
327
348
}
328
349
}
350
+ // In case variadic position found, set it in the context.
351
+ (* sctx )-> variadic = variadic ;
329
352
return argpos ;
330
353
}
331
354
@@ -338,4 +361,8 @@ void RedisAI_ReplyOrSetError(RedisModuleCtx *ctx, RAI_Error *error, RAI_ErrorCod
338
361
} else {
339
362
RedisModule_ReplyWithError (ctx , errorMessage );
340
363
}
341
- }
364
+ }
365
+
366
+ RedisModuleType * RAI_ScriptRedisType (void ) {
367
+ return RedisAI_ScriptType ;
368
+ }
0 commit comments