@@ -289,6 +289,61 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
289
289
return ret ;
290
290
}
291
291
292
+ void RAI_RStringDataTensorDeleter (DLManagedTensor * arg ) {
293
+ if (arg -> dl_tensor .shape ) {
294
+ RedisModule_Free (arg -> dl_tensor .shape );
295
+ }
296
+ if (arg -> dl_tensor .strides ) {
297
+ RedisModule_Free (arg -> dl_tensor .strides );
298
+ }
299
+ if (arg -> manager_ctx ) {
300
+ RedisModuleString * rstr = (RedisModuleString * )arg -> manager_ctx ;
301
+ RedisModule_FreeString (NULL , rstr );
302
+ }
303
+
304
+ RedisModule_Free (arg );
305
+ }
306
+
307
+ RAI_Tensor * RAI_TensorCreateWithDLDataTypeAndRString (DLDataType dtype , long long * dims , int ndims ,
308
+ RedisModuleString * rstr ) {
309
+ const size_t dtypeSize = Tensor_DataTypeSize (dtype );
310
+ if (dtypeSize == 0 ) {
311
+ return NULL ;
312
+ }
313
+
314
+ RAI_Tensor * ret = RedisModule_Alloc (sizeof (* ret ));
315
+ int64_t * shape = RedisModule_Alloc (ndims * sizeof (* shape ));
316
+ int64_t * strides = RedisModule_Alloc (ndims * sizeof (* strides ));
317
+
318
+ size_t len = 1 ;
319
+ for (int64_t i = 0 ; i < ndims ; ++ i ) {
320
+ shape [i ] = dims [i ];
321
+ strides [i ] = 1 ;
322
+ len *= dims [i ];
323
+ }
324
+ for (int64_t i = ndims - 2 ; i >= 0 ; -- i ) {
325
+ strides [i ] *= strides [i + 1 ] * shape [i + 1 ];
326
+ }
327
+
328
+ DLContext ctx = (DLContext ){.device_type = kDLCPU , .device_id = 0 };
329
+
330
+ char * data = (char * )RedisModule_StringPtrLen (rstr , NULL );
331
+
332
+ ret -> tensor = (DLManagedTensor ){.dl_tensor = (DLTensor ){.ctx = ctx ,
333
+ .data = data ,
334
+ .ndim = ndims ,
335
+ .dtype = dtype ,
336
+ .shape = shape ,
337
+ .strides = strides ,
338
+ .byte_offset = 0 },
339
+ .manager_ctx = rstr ,
340
+ .deleter = RAI_RStringDataTensorDeleter };
341
+
342
+ ret -> refCount = 1 ;
343
+ return ret ;
344
+ }
345
+
346
+
292
347
RAI_Tensor * RAI_TensorCreate (const char * dataType , long long * dims , int ndims , int hasdata ) {
293
348
DLDataType dtype = RAI_TensorDataTypeFromString (dataType );
294
349
return RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
@@ -814,7 +869,15 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
814
869
size_t datalen ;
815
870
const char * data ;
816
871
DLDataType datatype = RAI_TensorDataTypeFromString (typestr );
817
- * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , tensorAllocMode );
872
+ if (datafmt == REDISAI_DATA_BLOB ) {
873
+ RedisModuleString * rstr = argv [argpos ];
874
+ RedisModule_RetainString (NULL , rstr );
875
+ * t = RAI_TensorCreateWithDLDataTypeAndRString (datatype , dims , ndims , rstr );
876
+ }
877
+ else {
878
+ * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , tensorAllocMode );
879
+ }
880
+
818
881
if (!t ) {
819
882
array_free (dims );
820
883
if (ctx == NULL ) {
@@ -826,22 +889,6 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
826
889
}
827
890
long i = 0 ;
828
891
switch (datafmt ) {
829
- case REDISAI_DATA_BLOB : {
830
- const char * blob = RedisModule_StringPtrLen (argv [argpos ], & datalen );
831
- if (datalen != nbytes ) {
832
- RAI_TensorFree (* t );
833
- array_free (dims );
834
- if (ctx == NULL ) {
835
- RAI_SetError (error , RAI_ETENSORSET ,
836
- "ERR data length does not match tensor shape and type" );
837
- } else {
838
- RedisModule_ReplyWithError (ctx ,
839
- "ERR data length does not match tensor shape and type" );
840
- }
841
- return -1 ;
842
- }
843
- RAI_TensorSetData (* t , blob , datalen );
844
- } break ;
845
892
case REDISAI_DATA_VALUES :
846
893
for (; (argpos <= argc - 1 ) && (i < len ); argpos ++ ) {
847
894
if (datatype .code == kDLFloat ) {
0 commit comments