@@ -289,6 +289,60 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
289
289
return ret ;
290
290
}
291
291
292
+ void RAI_RStringDataTensorDeleter (DLManagedTensor * arg ) {
293
+ if (arg -> dl_tensor .shape ) {
294
+ RedisModule_Free (arg -> dl_tensor .shape );
295
+ }
296
+ if (arg -> dl_tensor .strides ) {
297
+ RedisModule_Free (arg -> dl_tensor .strides );
298
+ }
299
+ if (arg -> manager_ctx ) {
300
+ RedisModuleString * rstr = (RedisModuleString * )arg -> manager_ctx ;
301
+ RedisModule_FreeString (NULL , rstr );
302
+ }
303
+
304
+ RedisModule_Free (arg );
305
+ }
306
+
307
+ RAI_Tensor * RAI_TensorCreateWithDLDataTypeAndRString (DLDataType dtype , long long * dims , int ndims ,
308
+ RedisModuleString * rstr ) {
309
+ const size_t dtypeSize = Tensor_DataTypeSize (dtype );
310
+ if (dtypeSize == 0 ) {
311
+ return NULL ;
312
+ }
313
+
314
+ RAI_Tensor * ret = RedisModule_Alloc (sizeof (* ret ));
315
+ int64_t * shape = RedisModule_Alloc (ndims * sizeof (* shape ));
316
+ int64_t * strides = RedisModule_Alloc (ndims * sizeof (* strides ));
317
+
318
+ size_t len = 1 ;
319
+ for (int64_t i = 0 ; i < ndims ; ++ i ) {
320
+ shape [i ] = dims [i ];
321
+ strides [i ] = 1 ;
322
+ len *= dims [i ];
323
+ }
324
+ for (int64_t i = ndims - 2 ; i >= 0 ; -- i ) {
325
+ strides [i ] *= strides [i + 1 ] * shape [i + 1 ];
326
+ }
327
+
328
+ DLContext ctx = (DLContext ){.device_type = kDLCPU , .device_id = 0 };
329
+
330
+ char * data = (char * )RedisModule_StringPtrLen (rstr , NULL );
331
+
332
+ ret -> tensor = (DLManagedTensor ){.dl_tensor = (DLTensor ){.ctx = ctx ,
333
+ .data = data ,
334
+ .ndim = ndims ,
335
+ .dtype = dtype ,
336
+ .shape = shape ,
337
+ .strides = strides ,
338
+ .byte_offset = 0 },
339
+ .manager_ctx = rstr ,
340
+ .deleter = RAI_RStringDataTensorDeleter };
341
+
342
+ ret -> refCount = 1 ;
343
+ return ret ;
344
+ }
345
+
292
346
RAI_Tensor * RAI_TensorCreate (const char * dataType , long long * dims , int ndims , int hasdata ) {
293
347
DLDataType dtype = RAI_TensorDataTypeFromString (dataType );
294
348
return RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
@@ -815,7 +869,14 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
815
869
size_t datalen ;
816
870
const char * data ;
817
871
DLDataType datatype = RAI_TensorDataTypeFromString (typestr );
818
- * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , tensorAllocMode );
872
+ if (datafmt == REDISAI_DATA_BLOB ) {
873
+ RedisModuleString * rstr = argv [argpos ];
874
+ RedisModule_RetainString (NULL , rstr );
875
+ * t = RAI_TensorCreateWithDLDataTypeAndRString (datatype , dims , ndims , rstr );
876
+ } else {
877
+ * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , tensorAllocMode );
878
+ }
879
+
819
880
if (!t ) {
820
881
array_free (dims );
821
882
if (ctx == NULL ) {
@@ -826,24 +887,7 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
826
887
return -1 ;
827
888
}
828
889
long i = 0 ;
829
- switch (datafmt ) {
830
- case REDISAI_DATA_BLOB : {
831
- const char * blob = RedisModule_StringPtrLen (argv [argpos ], & datalen );
832
- if (datalen != nbytes ) {
833
- RAI_TensorFree (* t );
834
- array_free (dims );
835
- if (ctx == NULL ) {
836
- RAI_SetError (error , RAI_ETENSORSET ,
837
- "ERR data length does not match tensor shape and type" );
838
- } else {
839
- RedisModule_ReplyWithError (ctx ,
840
- "ERR data length does not match tensor shape and type" );
841
- }
842
- return -1 ;
843
- }
844
- RAI_TensorSetData (* t , blob , datalen );
845
- } break ;
846
- case REDISAI_DATA_VALUES :
890
+ if (datafmt == REDISAI_DATA_VALUES ) {
847
891
for (; (argpos <= argc - 1 ) && (i < len ); argpos ++ ) {
848
892
if (datatype .code == kDLFloat ) {
849
893
double val ;
@@ -900,10 +944,6 @@ int RAI_parseTensorSetArgs(RedisModuleCtx *ctx, RedisModuleString **argv, int ar
900
944
}
901
945
i ++ ;
902
946
}
903
- break ;
904
- default :
905
- // default does not require tensor data setting since calloc setted it to 0
906
- break ;
907
947
}
908
948
array_free (dims );
909
949
return argpos ;
0 commit comments