45
45
#include "util/dict.h"
46
46
#include "util/queue.h"
47
47
#include "dag_parser.h"
48
+ #include "util/string_utils.h"
48
49
49
50
/**
50
51
* Execution of a TENSORSET DAG step.
@@ -59,7 +60,7 @@ void RedisAI_DagRunSession_TensorSet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
59
60
const int parse_result =
60
61
RAI_parseTensorSetArgs (NULL , currentOp -> argv , currentOp -> argc , & t , 0 , currentOp -> err );
61
62
if (parse_result > 0 ) {
62
- const char * key_string = RedisModule_StringPtrLen ( currentOp -> outkeys [0 ], NULL ) ;
63
+ RedisModuleString * key_string = currentOp -> outkeys [0 ];
63
64
RAI_ContextWriteLock (rinfo );
64
65
AI_dictReplace (rinfo -> dagTensorsContext , (void * )key_string , t );
65
66
RAI_ContextUnlock (rinfo );
@@ -78,7 +79,7 @@ void RedisAI_DagRunSession_TensorSet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
78
79
* @return
79
80
*/
80
81
void RedisAI_DagRunSession_TensorGet_Step (RedisAI_RunInfo * rinfo , RAI_DagOp * currentOp ) {
81
- const char * key_string = RedisModule_StringPtrLen ( currentOp -> inkeys [0 ], NULL ) ;
82
+ RedisModuleString * key_string = currentOp -> inkeys [0 ];
82
83
RAI_Tensor * t = NULL ;
83
84
RAI_ContextReadLock (rinfo );
84
85
currentOp -> result = RAI_getTensorFromLocalContext (NULL , rinfo -> dagTensorsContext , key_string ,
@@ -102,8 +103,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre
102
103
for (uint i = 0 ; i < n_inkeys ; i ++ ) {
103
104
RAI_Tensor * inputTensor ;
104
105
const int get_result = RAI_getTensorFromLocalContext (
105
- NULL , rinfo -> dagTensorsContext , RedisModule_StringPtrLen (currentOp -> inkeys [i ], NULL ),
106
- & inputTensor , currentOp -> err );
106
+ NULL , rinfo -> dagTensorsContext , currentOp -> inkeys [i ], & inputTensor , currentOp -> err );
107
107
if (get_result == REDISMODULE_ERR ) {
108
108
// We check for this outside the function
109
109
// this check cannot be covered by tests
@@ -141,9 +141,8 @@ static void Dag_StoreOutputsFromModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *c
141
141
const size_t noutputs = RAI_ModelRunCtxNumOutputs (currentOp -> mctx );
142
142
for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
143
143
RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (currentOp -> mctx , outputNumber );
144
- const char * key_string = RedisModule_StringPtrLen (currentOp -> outkeys [outputNumber ], NULL );
145
144
tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
146
- AI_dictReplace (rinfo -> dagTensorsContext , (void * )key_string , tensor );
145
+ AI_dictReplace (rinfo -> dagTensorsContext , (void * )currentOp -> outkeys [ outputNumber ] , tensor );
147
146
}
148
147
RAI_ContextUnlock (rinfo );
149
148
}
@@ -244,8 +243,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
244
243
for (uint i = 0 ; i < n_inkeys ; i ++ ) {
245
244
RAI_Tensor * inputTensor ;
246
245
const int get_result = RAI_getTensorFromLocalContext (
247
- NULL , rinfo -> dagTensorsContext , RedisModule_StringPtrLen (currentOp -> inkeys [i ], NULL ),
248
- & inputTensor , currentOp -> err );
246
+ NULL , rinfo -> dagTensorsContext , currentOp -> inkeys [i ], & inputTensor , currentOp -> err );
249
247
if (get_result == REDISMODULE_ERR ) {
250
248
// We check for this outside the function
251
249
// this check cannot be covered by tests
@@ -275,7 +273,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
275
273
const size_t noutputs = RAI_ScriptRunCtxNumOutputs (currentOp -> sctx );
276
274
for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
277
275
RAI_Tensor * tensor = RAI_ScriptRunCtxOutputTensor (currentOp -> sctx , outputNumber );
278
- const char * key_string = RedisModule_StringPtrLen ( currentOp -> outkeys [outputNumber ], NULL ) ;
276
+ RedisModuleString * key_string = currentOp -> outkeys [outputNumber ];
279
277
tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
280
278
AI_dictReplace (rinfo -> dagTensorsContext , (void * )key_string , tensor );
281
279
}
@@ -304,8 +302,7 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) {
304
302
if (rinfo -> single_op_dag ) {
305
303
input = op -> mctx -> inputs [i ].tensor ;
306
304
} else {
307
- RAI_getTensorFromLocalContext (NULL , rinfo -> dagTensorsContext ,
308
- RedisModule_StringPtrLen (op -> inkeys [i ], NULL ), & input ,
305
+ RAI_getTensorFromLocalContext (NULL , rinfo -> dagTensorsContext , op -> inkeys [i ], & input ,
309
306
op -> err );
310
307
}
311
308
// We are expecting input != NULL, because we only reach this function if all inputs
@@ -354,16 +351,14 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
354
351
if (rinfo1 -> single_op_dag == 1 ) {
355
352
input1 = op1 -> mctx -> inputs [i ].tensor ;
356
353
} else {
357
- RAI_getTensorFromLocalContext (NULL , rinfo1 -> dagTensorsContext ,
358
- RedisModule_StringPtrLen (op1 -> inkeys [i ], NULL ), & input1 ,
354
+ RAI_getTensorFromLocalContext (NULL , rinfo1 -> dagTensorsContext , op1 -> inkeys [i ], & input1 ,
359
355
op1 -> err );
360
356
}
361
357
RAI_Tensor * input2 ;
362
358
if (rinfo2 -> single_op_dag == 1 ) {
363
359
input2 = op2 -> mctx -> inputs [i ].tensor ;
364
360
} else {
365
- RAI_getTensorFromLocalContext (NULL , rinfo2 -> dagTensorsContext ,
366
- RedisModule_StringPtrLen (op2 -> inkeys [i ], NULL ), & input2 ,
361
+ RAI_getTensorFromLocalContext (NULL , rinfo2 -> dagTensorsContext , op2 -> inkeys [i ], & input2 ,
367
362
op2 -> err );
368
363
}
369
364
if (input1 == NULL || input2 == NULL ) {
@@ -439,8 +434,7 @@ void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
439
434
RAI_ContextReadLock (rinfo );
440
435
441
436
for (int i = 0 ; i < n_inkeys ; i ++ ) {
442
- if (AI_dictFind (rinfo -> dagTensorsContext ,
443
- RedisModule_StringPtrLen (currentOp_ -> inkeys [i ], NULL )) == NULL ) {
437
+ if (AI_dictFind (rinfo -> dagTensorsContext , currentOp_ -> inkeys [i ]) == NULL ) {
444
438
RAI_ContextUnlock (rinfo );
445
439
* currentOpReady = 0 ;
446
440
return ;
@@ -543,17 +537,22 @@ void RedisAI_BatchedDagRunSessionStep(RedisAI_RunInfo **batched_rinfo, const cha
543
537
}
544
538
545
539
static int _StoreTensorInKeySpace (RedisModuleCtx * ctx , RAI_Tensor * tensor ,
546
- const char * persist_key_name , bool mangled_name ) {
540
+ RedisModuleString * persist_key_name , bool mangled_name ) {
541
+
547
542
int ret = REDISMODULE_ERR ;
548
543
RedisModuleKey * key ;
549
- char * demangled_key_name = RedisModule_Strdup (persist_key_name );
550
- if (mangled_name )
551
- demangled_key_name [strlen (persist_key_name ) - 4 ] = 0 ;
552
- RedisModuleString * tensor_keyname =
553
- RedisModule_CreateString (ctx , demangled_key_name , strlen (demangled_key_name ));
544
+ size_t persist_key_len ;
545
+ const char * persist_key_str = RedisModule_StringPtrLen (persist_key_name , & persist_key_len );
546
+
547
+ RedisModuleString * demangled_key_name ;
548
+ if (mangled_name ) {
549
+ demangled_key_name = RedisModule_CreateString (NULL , persist_key_str , persist_key_len - 4 );
550
+ } else {
551
+ demangled_key_name = RedisModule_CreateString (NULL , persist_key_str , persist_key_len );
552
+ }
553
+
554
554
const int status =
555
- RAI_OpenKey_Tensor (ctx , tensor_keyname , & key , REDISMODULE_READ | REDISMODULE_WRITE );
556
- RedisModule_Free (demangled_key_name );
555
+ RAI_OpenKey_Tensor (ctx , demangled_key_name , & key , REDISMODULE_READ | REDISMODULE_WRITE );
557
556
if (status == REDISMODULE_ERR ) {
558
557
RedisModule_ReplyWithError (ctx , "ERR could not save tensor" );
559
558
goto clean_up ;
@@ -565,17 +564,19 @@ static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
565
564
}
566
565
}
567
566
ret = REDISMODULE_OK ;
567
+
568
568
clean_up :
569
569
RedisModule_CloseKey (key );
570
- RedisAI_ReplicateTensorSet (ctx , tensor_keyname , tensor );
570
+ RedisAI_ReplicateTensorSet (ctx , demangled_key_name , tensor );
571
+ RedisModule_FreeString (NULL , demangled_key_name );
571
572
return ret ;
572
573
}
573
574
574
575
static void PersistTensors (RedisModuleCtx * ctx , RedisAI_RunInfo * rinfo ) {
575
576
AI_dictIterator * persist_iter = AI_dictGetSafeIterator (rinfo -> dagTensorsPersistedContext );
576
577
AI_dictEntry * persist_entry = AI_dictNext (persist_iter );
577
578
while (persist_entry ) {
578
- const char * persist_key_name = AI_dictGetKey (persist_entry );
579
+ RedisModuleString * persist_key_name = AI_dictGetKey (persist_entry );
579
580
580
581
AI_dictEntry * tensor_entry = AI_dictFind (rinfo -> dagTensorsContext , persist_key_name );
581
582
@@ -586,8 +587,7 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
586
587
persist_entry = AI_dictNext (persist_iter );
587
588
continue ;
588
589
}
589
- bool mangled = true;
590
- if (_StoreTensorInKeySpace (ctx , tensor , persist_key_name , mangled ) == REDISMODULE_ERR )
590
+ if (_StoreTensorInKeySpace (ctx , tensor , persist_key_name , true) == REDISMODULE_ERR )
591
591
rinfo -> dagReplyLength ++ ;
592
592
593
593
} else {
@@ -602,7 +602,7 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
602
602
AI_dictIterator * local_iter = AI_dictGetSafeIterator (rinfo -> dagTensorsContext );
603
603
AI_dictEntry * local_entry = AI_dictNext (local_iter );
604
604
while (local_entry ) {
605
- const char * localcontext_key_name = AI_dictGetKey (local_entry );
605
+ RedisModuleString * localcontext_key_name = AI_dictGetKey (local_entry );
606
606
RedisModule_Log (ctx , "warning" , "DAG's local context key (%s)" ,
607
607
localcontext_key_name );
608
608
local_entry = AI_dictNext (local_iter );
@@ -623,11 +623,9 @@ static void ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
623
623
const size_t noutputs = RAI_ModelRunCtxNumOutputs (op -> mctx );
624
624
for (size_t outputNumber = 0 ; outputNumber < noutputs ; outputNumber ++ ) {
625
625
RAI_Tensor * tensor = RAI_ModelRunCtxOutputTensor (op -> mctx , outputNumber );
626
- const char * key_string = RedisModule_StringPtrLen (op -> outkeys [outputNumber ], NULL );
627
626
tensor = tensor ? RAI_TensorGetShallowCopy (tensor ) : NULL ;
628
- bool mangled = false;
629
627
if (tensor )
630
- _StoreTensorInKeySpace (ctx , tensor , key_string , mangled );
628
+ _StoreTensorInKeySpace (ctx , tensor , op -> outkeys [ outputNumber ], false );
631
629
}
632
630
}
633
631
@@ -693,8 +691,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
693
691
case REDISAI_DAG_CMD_MODELRUN : {
694
692
rinfo -> dagReplyLength ++ ;
695
693
struct RedisAI_RunStats * rstats = NULL ;
696
- const char * runkey = RedisModule_StringPtrLen (currentOp -> runkey , NULL );
697
- RAI_GetRunStats (runkey , & rstats );
694
+ RAI_GetRunStats (currentOp -> runkey , & rstats );
698
695
if (currentOp -> result == REDISMODULE_ERR ) {
699
696
RAI_SafeAddDataPoint (rstats , 0 , 1 , 1 , 0 );
700
697
RedisModule_ReplyWithError (ctx , currentOp -> err -> detail_oneline );
@@ -719,8 +716,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
719
716
case REDISAI_DAG_CMD_SCRIPTRUN : {
720
717
rinfo -> dagReplyLength ++ ;
721
718
struct RedisAI_RunStats * rstats = NULL ;
722
- const char * runkey = RedisModule_StringPtrLen (currentOp -> runkey , NULL );
723
- RAI_GetRunStats (runkey , & rstats );
719
+ RAI_GetRunStats (currentOp -> runkey , & rstats );
724
720
if (currentOp -> result == REDISMODULE_ERR ) {
725
721
RAI_SafeAddDataPoint (rstats , 0 , 1 , 1 , 0 );
726
722
RedisModule_ReplyWithError (ctx , currentOp -> err -> detail_oneline );
0 commit comments