Skip to content

Commit 65787a1

Browse files
authored
Reuse memory in TENSORSET (#540)
* Reuse memory allocated for argv in TENSORSET * clang formatting * Remove useless switch
1 parent 3f9d4f0 commit 65787a1

File tree

1 file changed

+63
-23
lines changed

1 file changed

+63
-23
lines changed

src/tensor.c

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,60 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
289289
return ret;
290290
}
291291

292+
void RAI_RStringDataTensorDeleter(DLManagedTensor *arg) {
293+
if (arg->dl_tensor.shape) {
294+
RedisModule_Free(arg->dl_tensor.shape);
295+
}
296+
if (arg->dl_tensor.strides) {
297+
RedisModule_Free(arg->dl_tensor.strides);
298+
}
299+
if (arg->manager_ctx) {
300+
RedisModuleString *rstr = (RedisModuleString *)arg->manager_ctx;
301+
RedisModule_FreeString(NULL, rstr);
302+
}
303+
304+
RedisModule_Free(arg);
305+
}
306+
307+
RAI_Tensor *RAI_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, long long *dims, int ndims,
308+
RedisModuleString *rstr) {
309+
const size_t dtypeSize = Tensor_DataTypeSize(dtype);
310+
if (dtypeSize == 0) {
311+
return NULL;
312+
}
313+
314+
RAI_Tensor *ret = RedisModule_Alloc(sizeof(*ret));
315+
int64_t *shape = RedisModule_Alloc(ndims * sizeof(*shape));
316+
int64_t *strides = RedisModule_Alloc(ndims * sizeof(*strides));
317+
318+
size_t len = 1;
319+
for (int64_t i = 0; i < ndims; ++i) {
320+
shape[i] = dims[i];
321+
strides[i] = 1;
322+
len *= dims[i];
323+
}
324+
for (int64_t i = ndims - 2; i >= 0; --i) {
325+
strides[i] *= strides[i + 1] * shape[i + 1];
326+
}
327+
328+
DLContext ctx = (DLContext){.device_type = kDLCPU, .device_id = 0};
329+
330+
char *data = (char *)RedisModule_StringPtrLen(rstr, NULL);
331+
332+
ret->tensor = (DLManagedTensor){.dl_tensor = (DLTensor){.ctx = ctx,
333+
.data = data,
334+
.ndim = ndims,
335+
.dtype = dtype,
336+
.shape = shape,
337+
.strides = strides,
338+
.byte_offset = 0},
339+
.manager_ctx = rstr,
340+
.deleter = RAI_RStringDataTensorDeleter};
341+
342+
ret->refCount = 1;
343+
return ret;
344+
}
345+
292346
RAI_Tensor *RAI_TensorCreate(const char *dataType, long long *dims, int ndims, int hasdata) {
293347
DLDataType dtype = RAI_TensorDataTypeFromString(dataType);
294348
return RAI_TensorCreateWithDLDataType(dtype, dims, ndims, TENSORALLOC_ALLOC);
@@ -815,7 +869,14 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
815869
size_t datalen;
816870
const char *data;
817871
DLDataType datatype = RAI_TensorDataTypeFromString(typestr);
818-
*t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims, tensorAllocMode);
872+
if (datafmt == REDISAI_DATA_BLOB) {
873+
RedisModuleString *rstr = argv[argpos];
874+
RedisModule_RetainString(NULL, rstr);
875+
*t = RAI_TensorCreateWithDLDataTypeAndRString(datatype, dims, ndims, rstr);
876+
} else {
877+
*t = RAI_TensorCreateWithDLDataType(datatype, dims, ndims, tensorAllocMode);
878+
}
879+
819880
if (!t) {
820881
array_free(dims);
821882
if (ctx == NULL) {
@@ -826,24 +887,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
826887
return -1;
827888
}
828889
long i = 0;
829-
switch (datafmt) {
830-
case REDISAI_DATA_BLOB: {
831-
const char *blob = RedisModule_StringPtrLen(argv[argpos], &datalen);
832-
if (datalen != nbytes) {
833-
RAI_TensorFree(*t);
834-
array_free(dims);
835-
if (ctx == NULL) {
836-
RAI_SetError(error, RAI_ETENSORSET,
837-
"ERR data length does not match tensor shape and type");
838-
} else {
839-
RedisModule_ReplyWithError(ctx,
840-
"ERR data length does not match tensor shape and type");
841-
}
842-
return -1;
843-
}
844-
RAI_TensorSetData(*t, blob, datalen);
845-
} break;
846-
case REDISAI_DATA_VALUES:
890+
if (datafmt == REDISAI_DATA_VALUES) {
847891
for (; (argpos <= argc - 1) && (i < len); argpos++) {
848892
if (datatype.code == kDLFloat) {
849893
double val;
@@ -900,10 +944,6 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
900944
}
901945
i++;
902946
}
903-
break;
904-
default:
905-
// default does not require tensor data setting since calloc setted it to 0
906-
break;
907947
}
908948
array_free(dims);
909949
return argpos;

0 commit comments

Comments
 (0)