@@ -238,13 +238,13 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
238
238
239
239
TF_ImportGraphDefOptions * options = TF_NewImportGraphDefOptions ();
240
240
241
- TF_Buffer * buffer = TF_NewBuffer ();
242
- buffer -> length = modellen ;
243
- buffer -> data = modeldef ;
241
+ TF_Buffer * tfbuffer = TF_NewBuffer ();
242
+ tfbuffer -> length = modellen ;
243
+ tfbuffer -> data = modeldef ;
244
244
245
245
TF_Status * status = TF_NewStatus ();
246
246
247
- TF_GraphImportGraphDef (model , buffer , options , status );
247
+ TF_GraphImportGraphDef (model , tfbuffer , options , status );
248
248
249
249
if (TF_GetCode (status ) != TF_OK ) {
250
250
char * errorMessage = RedisModule_Strdup (TF_Message (status ));
@@ -276,7 +276,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
276
276
}
277
277
278
278
TF_DeleteImportGraphDefOptions (options );
279
- TF_DeleteBuffer (buffer );
279
+ TF_DeleteBuffer (tfbuffer );
280
280
TF_DeleteStatus (status );
281
281
282
282
TF_Status * optionsStatus = TF_NewStatus ();
@@ -394,6 +394,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
394
394
array_append (outputs_ , RedisModule_Strdup (outputs [i ]));
395
395
}
396
396
397
+ char * buffer = RedisModule_Calloc (modellen , sizeof (* buffer ));
398
+ memcpy (buffer , modeldef , modellen );
399
+
397
400
RAI_Model * ret = RedisModule_Calloc (1 , sizeof (* ret ));
398
401
ret -> model = model ;
399
402
ret -> session = session ;
@@ -403,7 +406,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
403
406
ret -> outputs = outputs_ ;
404
407
ret -> opts = opts ;
405
408
ret -> refCount = 1 ;
406
-
409
+ ret -> data = buffer ;
410
+ ret -> datalen = modellen ;
411
+
407
412
return ret ;
408
413
}
409
414
@@ -445,6 +450,10 @@ void RAI_ModelFreeTF(RAI_Model* model, RAI_Error* error) {
445
450
array_free (model -> outputs );
446
451
}
447
452
453
+ if (model -> data ) {
454
+ RedisModule_Free (model -> data );
455
+ }
456
+
448
457
TF_DeleteStatus (status );
449
458
}
450
459
@@ -534,24 +543,32 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
534
543
}
535
544
536
545
int RAI_ModelSerializeTF (RAI_Model * model , char * * buffer , size_t * len , RAI_Error * error ) {
537
- TF_Buffer * tf_buffer = TF_NewBuffer ();
538
- TF_Status * status = TF_NewStatus ();
539
546
540
- TF_GraphToGraphDef (model -> model , tf_buffer , status );
547
+ if (model -> data ) {
548
+ * buffer = RedisModule_Calloc (model -> datalen , sizeof (char ));
549
+ memcpy (* buffer , model -> data , model -> datalen );
550
+ * len = model -> datalen ;
551
+ }
552
+ else {
553
+ TF_Buffer * tf_buffer = TF_NewBuffer ();
554
+ TF_Status * status = TF_NewStatus ();
555
+
556
+ TF_GraphToGraphDef (model -> model , tf_buffer , status );
557
+
558
+ if (TF_GetCode (status ) != TF_OK ) {
559
+ RAI_SetError (error , RAI_EMODELSERIALIZE , "ERR Error serializing TF model" );
560
+ TF_DeleteBuffer (tf_buffer );
561
+ TF_DeleteStatus (status );
562
+ return 1 ;
563
+ }
564
+
565
+ * buffer = RedisModule_Alloc (tf_buffer -> length );
566
+ memcpy (* buffer , tf_buffer -> data , tf_buffer -> length );
567
+ * len = tf_buffer -> length ;
541
568
542
- if (TF_GetCode (status ) != TF_OK ) {
543
- RAI_SetError (error , RAI_EMODELSERIALIZE , "ERR Error serializing TF model" );
544
569
TF_DeleteBuffer (tf_buffer );
545
570
TF_DeleteStatus (status );
546
- return 1 ;
547
571
}
548
572
549
- * buffer = RedisModule_Alloc (tf_buffer -> length );
550
- memcpy (* buffer , tf_buffer -> data , tf_buffer -> length );
551
- * len = tf_buffer -> length ;
552
-
553
- TF_DeleteBuffer (tf_buffer );
554
- TF_DeleteStatus (status );
555
-
556
573
return 0 ;
557
574
}
0 commit comments