@@ -46,14 +46,18 @@ static size_t Tensor_DataTypeSize(DLDataType dtype) {
46
46
return dtype .bits / 8 ;
47
47
}
48
48
49
- static void Tensor_DataTypeStr (DLDataType dtype , char * * dtypestr ) {
49
+ void Tensor_DataTypeStr (DLDataType dtype , char * * dtypestr ) {
50
50
* dtypestr = RedisModule_Calloc (8 , sizeof (char ));
51
51
if (dtype .code == kDLFloat ) {
52
52
if (dtype .bits == 32 ) {
53
- strcpy (* dtypestr , "FLOAT32 " );
53
+ strcpy (* dtypestr , "FLOAT " );
54
54
}
55
55
else if (dtype .bits == 64 ) {
56
- strcpy (* dtypestr , "FLOAT64" );
56
+ strcpy (* dtypestr , "DOUBLE" );
57
+ }
58
+ else {
59
+ RedisModule_Free (* dtypestr );
60
+ * dtypestr = NULL ;
57
61
}
58
62
}
59
63
else if (dtype .code == kDLInt ) {
@@ -69,6 +73,10 @@ static void Tensor_DataTypeStr(DLDataType dtype, char **dtypestr) {
69
73
else if (dtype .bits == 64 ) {
70
74
strcpy (* dtypestr , "INT64" );
71
75
}
76
+ else {
77
+ RedisModule_Free (* dtypestr );
78
+ * dtypestr = NULL ;
79
+ }
72
80
}
73
81
else if (dtype .code == kDLUInt ) {
74
82
if (dtype .bits == 8 ) {
@@ -77,6 +85,10 @@ static void Tensor_DataTypeStr(DLDataType dtype, char **dtypestr) {
77
85
else if (dtype .bits == 16 ) {
78
86
strcpy (* dtypestr , "UINT16" );
79
87
}
88
+ else {
89
+ RedisModule_Free (* dtypestr );
90
+ * dtypestr = NULL ;
91
+ }
80
92
}
81
93
}
82
94
@@ -175,51 +187,21 @@ static void RAI_Tensor_AofRewrite(RedisModuleIO *aof, RedisModuleString *key, vo
175
187
RAI_Tensor * tensor = (RAI_Tensor * )value ;
176
188
177
189
char * dtypestr = NULL ;
178
-
179
190
Tensor_DataTypeStr (RAI_TensorDataType (tensor ), & dtypestr );
180
191
181
- int64_t * shape = tensor -> tensor .dl_tensor .shape ;
182
- char * data = RAI_TensorData (tensor );
183
- size_t size = RAI_TensorByteSize (tensor );
192
+ char * data = RAI_TensorData (tensor );
193
+ long long size = RAI_TensorByteSize (tensor );
194
+
195
+ long long ndims = RAI_TensorNumDims (tensor );
196
+
197
+ RedisModuleString * dims [ndims ];
184
198
185
- // We switch over the dimensions of the tensor up to 7
186
- // The reason is that we don't have a way to pass a vector of long long to RedisModule_EmitAOF,
187
- // there's no format for it. Vector of strings is supported (format 'v').
188
- // This might change in the future, but it needs to change in redis/src/module.c
189
-
190
- switch (RAI_TensorNumDims (tensor )) {
191
- case 1 :
192
- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "sllcb" ,
193
- key , dtypestr , RAI_SPLICE_SHAPE_1 (shape ), "BLOB" , data , size );
194
- break ;
195
- case 2 :
196
- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "slllcb" ,
197
- key , dtypestr , RAI_SPLICE_SHAPE_2 (shape ), "BLOB" , data , size );
198
- break ;
199
- case 3 :
200
- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "sllllcb" ,
201
- key , dtypestr , RAI_SPLICE_SHAPE_3 (shape ), "BLOB" , data , size );
202
- break ;
203
- case 4 :
204
- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "slllllcb" ,
205
- key , dtypestr , RAI_SPLICE_SHAPE_4 (shape ), "BLOB" , data , size );
206
- break ;
207
- case 5 :
208
- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "sllllllcb" ,
209
- key , dtypestr , RAI_SPLICE_SHAPE_5 (shape ), "BLOB" , data , size );
210
- break ;
211
- case 6 :
212
- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "slllllllcb" ,
213
- key , dtypestr , RAI_SPLICE_SHAPE_6 (shape ), "BLOB" , data , size );
214
- break ;
215
- case 7 :
216
- RedisModule_EmitAOF (aof , "AI.TENSORSET" , "sllllllllcb" ,
217
- key , dtypestr , RAI_SPLICE_SHAPE_7 (shape ), "BLOB" , data , size );
218
- break ;
219
- default :
220
- printf ("ERR: AOF serialization supports tensors of dimension up to 7\n" );
199
+ for (long long i = 0 ; i < ndims ; i ++ ) {
200
+ dims [i ] = RedisModule_CreateStringFromLongLong (RedisModule_GetContextFromIO (aof ), RAI_TensorDim (tensor , i ));
221
201
}
222
202
203
+ RedisModule_EmitAOF (aof , "AI.TENSORSET" , "scvcb" , key , dtypestr , dims , ndims , "BLOB" , data , size );
204
+
223
205
RedisModule_Free (dtypestr );
224
206
}
225
207
0 commit comments