Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/backends/tensorflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr,
ret->outputs = outputs_;
ret->refCount = 1;


return ret;
}

Expand All @@ -338,6 +337,8 @@ void RAI_ModelFreeTF(RAI_Model* model, RAI_Error* error) {
TF_DeleteGraph(model->model);
model->model = NULL;

RedisModule_Free(model->devicestr);

if (model->inputs) {
size_t ninputs = array_len(model->inputs);
for (size_t i=0; i<ninputs; i++) {
Expand Down
6 changes: 3 additions & 3 deletions src/backends/torch.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, const char* devicestr,
RAI_Error *error) {
DLDeviceType dl_device;

RAI_Device device;
int64_t deviceid;
RAI_Device device = RAI_DEVICE_CPU;
int64_t deviceid = 0;

if (!parseDeviceStr(devicestr, &device, &deviceid)) {
RAI_SetError(error, RAI_EMODELCONFIGURE, "ERR unsupported device");
Expand Down Expand Up @@ -154,14 +154,14 @@ RAI_Script *RAI_ScriptCreateTorch(const char* devicestr, const char *scriptdef,
ret->devicestr = RedisModule_Strdup(devicestr);
ret->refCount = 1;


return ret;
}

void RAI_ScriptFreeTorch(RAI_Script* script, RAI_Error* error) {

torchDeallocContext(script->script);
RedisModule_Free(script->scriptdef);
RedisModule_Free(script->devicestr);
RedisModule_Free(script);
}

Expand Down
4 changes: 4 additions & 0 deletions src/model.c
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ static void RAI_Model_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, voi
"OUTPUTS", outputs_, model->noutputs,
buffer, len);

if (buffer) {
RedisModule_Free(buffer);
}

for (size_t i=0; i<model->ninputs; i++) {
RedisModule_FreeString(ctx, inputs_[i]);
}
Expand Down
15 changes: 14 additions & 1 deletion src/redisai.c
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ int RedisAI_TensorSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
const int hasdata = !AC_IsAtEnd(&ac);

const char* fmtstr;
int datafmt;
int datafmt = REDISAI_DATA_NONE;
if (hasdata) {
AC_GetString(&ac, &fmtstr, NULL, 0);
if (strcasecmp(fmtstr, "BLOB") == 0) {
Expand All @@ -251,12 +251,14 @@ int RedisAI_TensorSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
case REDISAI_DATA_BLOB:
AC_GetString(&ac, &data, &datalen, 0);
if (datalen != nbytes){
RAI_TensorFree(t);
return RedisModule_ReplyWithError(ctx, "ERR data length does not match tensor shape and type");
}
RAI_TensorSetData(t, data, datalen);
break;
case REDISAI_DATA_VALUES:
if (argc != len + 4 + ndims){
RAI_TensorFree(t);
return RedisModule_WrongArity(ctx);
}
DLDataType datatype = RAI_TensorDataType(t);
Expand Down Expand Up @@ -749,6 +751,8 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,

RedisModule_ReplyWithStringBuffer(ctx, buffer, len);

RedisModule_Free(buffer);

return REDISMODULE_OK;
}

Expand Down Expand Up @@ -817,6 +821,10 @@ void RedisAI_ReplicateTensorSet(RedisModuleCtx *ctx, RedisModuleString *key, RAI
RedisModule_Replicate(ctx, "AI.TENSORSET", "scvcb", key, dtypestr,
dims, ndims, "BLOB", data, size);

// for (long long i=0; i<ndims; i++) {
// RedisModule_Free(dims[i]);
// }

RedisModule_Free(dtypestr);
}

Expand Down Expand Up @@ -1174,23 +1182,27 @@ int RedisAI_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
int type = RedisModule_KeyType(argkey);
if (type == REDISMODULE_KEYTYPE_EMPTY) {
RedisModule_CloseKey(argkey);
RAI_ScriptRunCtxFree(sctx);
return RedisModule_ReplyWithError(ctx, "Input key is empty");
}
if (!(type == REDISMODULE_KEYTYPE_MODULE &&
RedisModule_ModuleTypeGetType(argkey) == RedisAI_TensorType)) {
RedisModule_CloseKey(argkey);
RAI_ScriptRunCtxFree(sctx);
return RedisModule_ReplyWithError(ctx, REDISMODULE_ERRORMSG_WRONGTYPE);
}
RAI_Tensor *t = RedisModule_ModuleTypeGetValue(argkey);
RedisModule_CloseKey(argkey);
if (!RAI_ScriptRunCtxAddInput(sctx, t)) {
RAI_ScriptRunCtxFree(sctx);
return RedisModule_ReplyWithError(ctx, "Input key not found.");
}
}

outkeys = RedisModule_Calloc(noutputs, sizeof(RedisModuleString*));
for (size_t i=0; i<noutputs; i++) {
if (!RAI_ScriptRunCtxAddOutput(sctx)) {
RAI_ScriptRunCtxFree(sctx);
return RedisModule_ReplyWithError(ctx, "Output key not found.");
}
RedisModule_RetainString(ctx, outputs[i]);
Expand All @@ -1207,6 +1219,7 @@ int RedisAI_ScriptRun_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
AI_dictEntry *entry = AI_dictFind(run_queues, sto->devicestr);
RunQueueInfo *run_queue_info = NULL;
if (!entry){
RAI_ScriptRunCtxFree(sctx);
return RedisModule_ReplyWithError(ctx, "Queue not initialized for device.");
}
else{
Expand Down