Skip to content

Commit fa2a334

Browse files
committed
Merge with master
2 parents 34c4e3e + e06c663 commit fa2a334

23 files changed

+277
-268
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ endif()
55
ADD_LIBRARY(redisai_obj OBJECT
66
util/dict.c
77
util/queue.c
8+
util/string_utils.c
89
redisai.c
910
command_parser.c
1011
run_info.c

src/DAG/dag.c

Lines changed: 33 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "util/dict.h"
4646
#include "util/queue.h"
4747
#include "dag_parser.h"
48+
#include "util/string_utils.h"
4849

4950
/**
5051
* Execution of a TENSORSET DAG step.
@@ -59,7 +60,7 @@ void RedisAI_DagRunSession_TensorSet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
5960
const int parse_result =
6061
RAI_parseTensorSetArgs(NULL, currentOp->argv, currentOp->argc, &t, 0, currentOp->err);
6162
if (parse_result > 0) {
62-
const char *key_string = RedisModule_StringPtrLen(currentOp->outkeys[0], NULL);
63+
RedisModuleString *key_string = currentOp->outkeys[0];
6364
RAI_ContextWriteLock(rinfo);
6465
AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, t);
6566
RAI_ContextUnlock(rinfo);
@@ -78,7 +79,7 @@ void RedisAI_DagRunSession_TensorSet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
7879
* @return
7980
*/
8081
void RedisAI_DagRunSession_TensorGet_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *currentOp) {
81-
const char *key_string = RedisModule_StringPtrLen(currentOp->inkeys[0], NULL);
82+
RedisModuleString *key_string = currentOp->inkeys[0];
8283
RAI_Tensor *t = NULL;
8384
RAI_ContextReadLock(rinfo);
8485
currentOp->result = RAI_getTensorFromLocalContext(NULL, rinfo->dagTensorsContext, key_string,
@@ -102,8 +103,7 @@ static void Dag_LoadInputsToModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *curre
102103
for (uint i = 0; i < n_inkeys; i++) {
103104
RAI_Tensor *inputTensor;
104105
const int get_result = RAI_getTensorFromLocalContext(
105-
NULL, rinfo->dagTensorsContext, RedisModule_StringPtrLen(currentOp->inkeys[i], NULL),
106-
&inputTensor, currentOp->err);
106+
NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err);
107107
if (get_result == REDISMODULE_ERR) {
108108
// We check for this outside the function
109109
// this check cannot be covered by tests
@@ -141,9 +141,8 @@ static void Dag_StoreOutputsFromModelRunCtx(RedisAI_RunInfo *rinfo, RAI_DagOp *c
141141
const size_t noutputs = RAI_ModelRunCtxNumOutputs(currentOp->mctx);
142142
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
143143
RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(currentOp->mctx, outputNumber);
144-
const char *key_string = RedisModule_StringPtrLen(currentOp->outkeys[outputNumber], NULL);
145144
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
146-
AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor);
145+
AI_dictReplace(rinfo->dagTensorsContext, (void *)currentOp->outkeys[outputNumber], tensor);
147146
}
148147
RAI_ContextUnlock(rinfo);
149148
}
@@ -244,8 +243,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
244243
for (uint i = 0; i < n_inkeys; i++) {
245244
RAI_Tensor *inputTensor;
246245
const int get_result = RAI_getTensorFromLocalContext(
247-
NULL, rinfo->dagTensorsContext, RedisModule_StringPtrLen(currentOp->inkeys[i], NULL),
248-
&inputTensor, currentOp->err);
246+
NULL, rinfo->dagTensorsContext, currentOp->inkeys[i], &inputTensor, currentOp->err);
249247
if (get_result == REDISMODULE_ERR) {
250248
// We check for this outside the function
251249
// this check cannot be covered by tests
@@ -275,7 +273,7 @@ void RedisAI_DagRunSession_ScriptRun_Step(RedisAI_RunInfo *rinfo, RAI_DagOp *cur
275273
const size_t noutputs = RAI_ScriptRunCtxNumOutputs(currentOp->sctx);
276274
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
277275
RAI_Tensor *tensor = RAI_ScriptRunCtxOutputTensor(currentOp->sctx, outputNumber);
278-
const char *key_string = RedisModule_StringPtrLen(currentOp->outkeys[outputNumber], NULL);
276+
RedisModuleString *key_string = currentOp->outkeys[outputNumber];
279277
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
280278
AI_dictReplace(rinfo->dagTensorsContext, (void *)key_string, tensor);
281279
}
@@ -304,8 +302,7 @@ size_t RAI_DagOpBatchSize(RAI_DagOp *op, RedisAI_RunInfo *rinfo) {
304302
if (rinfo->single_op_dag) {
305303
input = op->mctx->inputs[i].tensor;
306304
} else {
307-
RAI_getTensorFromLocalContext(NULL, rinfo->dagTensorsContext,
308-
RedisModule_StringPtrLen(op->inkeys[i], NULL), &input,
305+
RAI_getTensorFromLocalContext(NULL, rinfo->dagTensorsContext, op->inkeys[i], &input,
309306
op->err);
310307
}
311308
// We are expecting input != NULL, because we only reach this function if all inputs
@@ -354,16 +351,14 @@ int RAI_DagOpBatchable(RAI_DagOp *op1, RedisAI_RunInfo *rinfo1, RAI_DagOp *op2,
354351
if (rinfo1->single_op_dag == 1) {
355352
input1 = op1->mctx->inputs[i].tensor;
356353
} else {
357-
RAI_getTensorFromLocalContext(NULL, rinfo1->dagTensorsContext,
358-
RedisModule_StringPtrLen(op1->inkeys[i], NULL), &input1,
354+
RAI_getTensorFromLocalContext(NULL, rinfo1->dagTensorsContext, op1->inkeys[i], &input1,
359355
op1->err);
360356
}
361357
RAI_Tensor *input2;
362358
if (rinfo2->single_op_dag == 1) {
363359
input2 = op2->mctx->inputs[i].tensor;
364360
} else {
365-
RAI_getTensorFromLocalContext(NULL, rinfo2->dagTensorsContext,
366-
RedisModule_StringPtrLen(op2->inkeys[i], NULL), &input2,
361+
RAI_getTensorFromLocalContext(NULL, rinfo2->dagTensorsContext, op2->inkeys[i], &input2,
367362
op2->err);
368363
}
369364
if (input1 == NULL || input2 == NULL) {
@@ -439,8 +434,7 @@ void RedisAI_DagCurrentOpInfo(RedisAI_RunInfo *rinfo, int *currentOpReady,
439434
RAI_ContextReadLock(rinfo);
440435

441436
for (int i = 0; i < n_inkeys; i++) {
442-
if (AI_dictFind(rinfo->dagTensorsContext,
443-
RedisModule_StringPtrLen(currentOp_->inkeys[i], NULL)) == NULL) {
437+
if (AI_dictFind(rinfo->dagTensorsContext, currentOp_->inkeys[i]) == NULL) {
444438
RAI_ContextUnlock(rinfo);
445439
*currentOpReady = 0;
446440
return;
@@ -543,17 +537,22 @@ void RedisAI_BatchedDagRunSessionStep(RedisAI_RunInfo **batched_rinfo, const cha
543537
}
544538

545539
static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
546-
const char *persist_key_name, bool mangled_name) {
540+
RedisModuleString *persist_key_name, bool mangled_name) {
541+
547542
int ret = REDISMODULE_ERR;
548543
RedisModuleKey *key;
549-
char *demangled_key_name = RedisModule_Strdup(persist_key_name);
550-
if (mangled_name)
551-
demangled_key_name[strlen(persist_key_name) - 4] = 0;
552-
RedisModuleString *tensor_keyname =
553-
RedisModule_CreateString(ctx, demangled_key_name, strlen(demangled_key_name));
544+
size_t persist_key_len;
545+
const char *persist_key_str = RedisModule_StringPtrLen(persist_key_name, &persist_key_len);
546+
547+
RedisModuleString *demangled_key_name;
548+
if (mangled_name) {
549+
demangled_key_name = RedisModule_CreateString(NULL, persist_key_str, persist_key_len - 4);
550+
} else {
551+
demangled_key_name = RedisModule_CreateString(NULL, persist_key_str, persist_key_len);
552+
}
553+
554554
const int status =
555-
RAI_OpenKey_Tensor(ctx, tensor_keyname, &key, REDISMODULE_READ | REDISMODULE_WRITE);
556-
RedisModule_Free(demangled_key_name);
555+
RAI_OpenKey_Tensor(ctx, demangled_key_name, &key, REDISMODULE_READ | REDISMODULE_WRITE);
557556
if (status == REDISMODULE_ERR) {
558557
RedisModule_ReplyWithError(ctx, "ERR could not save tensor");
559558
goto clean_up;
@@ -565,17 +564,19 @@ static int _StoreTensorInKeySpace(RedisModuleCtx *ctx, RAI_Tensor *tensor,
565564
}
566565
}
567566
ret = REDISMODULE_OK;
567+
568568
clean_up:
569569
RedisModule_CloseKey(key);
570-
RedisAI_ReplicateTensorSet(ctx, tensor_keyname, tensor);
570+
RedisAI_ReplicateTensorSet(ctx, demangled_key_name, tensor);
571+
RedisModule_FreeString(NULL, demangled_key_name);
571572
return ret;
572573
}
573574

574575
static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
575576
AI_dictIterator *persist_iter = AI_dictGetSafeIterator(rinfo->dagTensorsPersistedContext);
576577
AI_dictEntry *persist_entry = AI_dictNext(persist_iter);
577578
while (persist_entry) {
578-
const char *persist_key_name = AI_dictGetKey(persist_entry);
579+
RedisModuleString *persist_key_name = AI_dictGetKey(persist_entry);
579580

580581
AI_dictEntry *tensor_entry = AI_dictFind(rinfo->dagTensorsContext, persist_key_name);
581582

@@ -586,8 +587,7 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
586587
persist_entry = AI_dictNext(persist_iter);
587588
continue;
588589
}
589-
bool mangled = true;
590-
if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, mangled) == REDISMODULE_ERR)
590+
if (_StoreTensorInKeySpace(ctx, tensor, persist_key_name, true) == REDISMODULE_ERR)
591591
rinfo->dagReplyLength++;
592592

593593
} else {
@@ -602,7 +602,7 @@ static void PersistTensors(RedisModuleCtx *ctx, RedisAI_RunInfo *rinfo) {
602602
AI_dictIterator *local_iter = AI_dictGetSafeIterator(rinfo->dagTensorsContext);
603603
AI_dictEntry *local_entry = AI_dictNext(local_iter);
604604
while (local_entry) {
605-
const char *localcontext_key_name = AI_dictGetKey(local_entry);
605+
RedisModuleString *localcontext_key_name = AI_dictGetKey(local_entry);
606606
RedisModule_Log(ctx, "warning", "DAG's local context key (%s)",
607607
localcontext_key_name);
608608
local_entry = AI_dictNext(local_iter);
@@ -623,11 +623,9 @@ static void ModelSingleOp_PersistTensors(RedisModuleCtx *ctx, RAI_DagOp *op) {
623623
const size_t noutputs = RAI_ModelRunCtxNumOutputs(op->mctx);
624624
for (size_t outputNumber = 0; outputNumber < noutputs; outputNumber++) {
625625
RAI_Tensor *tensor = RAI_ModelRunCtxOutputTensor(op->mctx, outputNumber);
626-
const char *key_string = RedisModule_StringPtrLen(op->outkeys[outputNumber], NULL);
627626
tensor = tensor ? RAI_TensorGetShallowCopy(tensor) : NULL;
628-
bool mangled = false;
629627
if (tensor)
630-
_StoreTensorInKeySpace(ctx, tensor, key_string, mangled);
628+
_StoreTensorInKeySpace(ctx, tensor, op->outkeys[outputNumber], false);
631629
}
632630
}
633631

@@ -693,8 +691,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
693691
case REDISAI_DAG_CMD_MODELRUN: {
694692
rinfo->dagReplyLength++;
695693
struct RedisAI_RunStats *rstats = NULL;
696-
const char *runkey = RedisModule_StringPtrLen(currentOp->runkey, NULL);
697-
RAI_GetRunStats(runkey, &rstats);
694+
RAI_GetRunStats(currentOp->runkey, &rstats);
698695
if (currentOp->result == REDISMODULE_ERR) {
699696
RAI_SafeAddDataPoint(rstats, 0, 1, 1, 0);
700697
RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline);
@@ -719,8 +716,7 @@ int RedisAI_DagRun_Reply(RedisModuleCtx *ctx, RedisModuleString **argv, int argc
719716
case REDISAI_DAG_CMD_SCRIPTRUN: {
720717
rinfo->dagReplyLength++;
721718
struct RedisAI_RunStats *rstats = NULL;
722-
const char *runkey = RedisModule_StringPtrLen(currentOp->runkey, NULL);
723-
RAI_GetRunStats(runkey, &rstats);
719+
RAI_GetRunStats(currentOp->runkey, &rstats);
724720
if (currentOp->result == REDISMODULE_ERR) {
725721
RAI_SafeAddDataPoint(rstats, 0, 1, 1, 0);
726722
RedisModule_ReplyWithError(ctx, currentOp->err->detail_oneline);

0 commit comments

Comments
 (0)