Skip to content

Commit f3fc135

Browse files
committed
Reuse memory allocated for argv in TENSORSET
1 parent f707a56 commit f3fc135

File tree

2 files changed

+64
-18
lines changed

2 files changed

+64
-18
lines changed

src/dag.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1189,7 +1189,6 @@ static int DAG_CommandParser(RedisModuleCtx *ctx, RedisModuleString **argv, int
11891189
if (!mangled_entry) {
11901190
AI_dictRelease(mangled_tensors);
11911191
AI_dictRelease(mangled_persisted);
1192-
RedisModule_ReplyWithError(ctx, "ERR PERSIST key cannot be found in DAG");
11931192
AI_dictReleaseIterator(iter);
11941193
RedisModule_ReplyWithError(ctx, "ERR PERSIST key cannot be found in DAG");
11951194
return REDISMODULE_ERR;

src/tensor.c

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,61 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
289289
return ret;
290290
}
291291

292+
void RAI_RStringDataTensorDeleter(DLManagedTensor *arg) {
293+
if (arg->dl_tensor.shape) {
294+
RedisModule_Free(arg->dl_tensor.shape);
295+
}
296+
if (arg->dl_tensor.strides) {
297+
RedisModule_Free(arg->dl_tensor.strides);
298+
}
299+
if (arg->manager_ctx) {
300+
RedisModuleString *rstr = (RedisModuleString *)arg->manager_ctx;
301+
RedisModule_FreeString(NULL, rstr);
302+
}
303+
304+
RedisModule_Free(arg);
305+
}
306+
307+
RAI_Tensor *RAI_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, long long *dims, int ndims,
308+
RedisModuleString *rstr) {
309+
const size_t dtypeSize = Tensor_DataTypeSize(dtype);
310+
if (dtypeSize == 0) {
311+
return NULL;
312+
}
313+
314+
RAI_Tensor *ret = RedisModule_Alloc(sizeof(*ret));
315+
int64_t *shape = RedisModule_Alloc(ndims * sizeof(*shape));
316+
int64_t *strides = RedisModule_Alloc(ndims * sizeof(*strides));
317+
318+
size_t len = 1;
319+
for (int64_t i = 0; i < ndims; ++i) {
320+
shape[i] = dims[i];
321+
strides[i] = 1;
322+
len *= dims[i];
323+
}
324+
for (int64_t i = ndims - 2; i >= 0; --i) {
325+
strides[i] *= strides[i + 1] * shape[i + 1];
326+
}
327+
328+
DLContext ctx = (DLContext){.device_type = kDLCPU, .device_id = 0};
329+
330+
char *data = (char *)RedisModule_StringPtrLen(rstr, NULL);
331+
332+
ret->tensor = (DLManagedTensor){.dl_tensor = (DLTensor){.ctx = ctx,
333+
.data = data,
334+
.ndim = ndims,
335+
.dtype = dtype,
336+
.shape = shape,
337+
.strides = strides,
338+
.byte_offset = 0},
339+
.manager_ctx = rstr,
340+
.deleter = RAI_RStringDataTensorDeleter };
341+
342+
ret->refCount = 1;
343+
return ret;
344+
}
345+
346+
292347
RAI_Tensor *RAI_TensorCreate(const char *dataType, long long *dims, int ndims, int hasdata) {
293348
DLDataType dtype = RAI_TensorDataTypeFromString(dataType);
294349
return RAI_TensorCreateWithDLDataType(dtype, dims, ndims, TENSORALLOC_ALLOC);
@@ -814,7 +869,15 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
814869
size_t datalen;
815870
const char *data;
816871
DLDataType datatype = RAI_TensorDataTypeFromString(typestr);
817-
*t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims, tensorAllocMode);
872+
if (datafmt == REDISAI_DATA_BLOB) {
873+
RedisModuleString *rstr = argv[argpos];
874+
RedisModule_RetainString(NULL, rstr);
875+
*t = RAI_TensorCreateWithDLDataTypeAndRString(datatype, dims, ndims, rstr);
876+
}
877+
else {
878+
*t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims, tensorAllocMode);
879+
}
880+
818881
if (!t) {
819882
array_free(dims);
820883
if (ctx == NULL) {
@@ -826,22 +889,6 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
826889
}
827890
long i = 0;
828891
switch (datafmt) {
829-
case REDISAI_DATA_BLOB: {
830-
const char *blob = RedisModule_StringPtrLen(argv[argpos], &datalen);
831-
if (datalen != nbytes) {
832-
RAI_TensorFree(*t);
833-
array_free(dims);
834-
if (ctx == NULL) {
835-
RAI_SetError(error, RAI_ETENSORSET,
836-
"ERR data length does not match tensor shape and type");
837-
} else {
838-
RedisModule_ReplyWithError(ctx,
839-
"ERR data length does not match tensor shape and type");
840-
}
841-
return -1;
842-
}
843-
RAI_TensorSetData(*t, blob, datalen);
844-
} break;
845892
case REDISAI_DATA_VALUES:
846893
for (; (argpos <= argc - 1) && (i < len); argpos++) {
847894
if (datatype.code == kDLFloat) {

0 commit comments

Comments
 (0)