@@ -163,7 +163,7 @@ OrtValue* RAI_OrtValueFromTensors(RAI_Tensor** ts, size_t count, RAI_Error *erro
     return NULL;
 }
 
-RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_t batch_size, RAI_Error *error) {
+RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, long long batch_size, RAI_Error *error) {
   OrtStatus* status = NULL;
   const OrtApi* ort = OrtGetApiBase()->GetApi(1);
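(Annotation: `batch_size` changes from `size_t` to `long long` so that the call sites added later in this diff can pass -1 as a "use the whole output" sentinel. `size_t` is unsigned, so a -1 sentinel cannot be represented cleanly there; a signed type makes the `batch_size != -1` check below explicit.)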
@@ -215,7 +215,12 @@ RAI_Tensor* RAI_TensorCreateFromOrtValue(OrtValue* v, size_t batch_offset, size_
     shape[i] = dims[i];
     strides[i] = 1;
   }
-  shape[0] = batch_size;
+  if (batch_size != -1) {
+    shape[0] = batch_size;
+  }
+  else {
+    batch_size = total_batch_size;
+  }
   for (int64_t i = ndims - 2; i >= 0; --i)
   {
     strides[i] *= strides[i + 1] * shape[i + 1];
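(Annotation: the loop above fills standard row-major strides. Since each `strides[i]` starts at 1, the in-place `*=` reduces to `strides[i] = strides[i+1] * shape[i+1]`. A minimal standalone sketch of the same computation, illustrative only, not RedisAI code:

#include <stdio.h>
#include <stdint.h>

int main(void) {
    int64_t shape[] = {4, 3, 2};   /* e.g. a batch of 4 entries, each 3x2 */
    int64_t ndims = 3;
    int64_t strides[3];
    for (int64_t i = 0; i < ndims; i++) strides[i] = 1;
    /* Same loop as in the hunk above. */
    for (int64_t i = ndims - 2; i >= 0; --i)
        strides[i] *= strides[i + 1] * shape[i + 1];
    /* Expect {6, 2, 1}: element (b, r, c) lives at b*6 + r*2 + c. */
    for (int64_t i = 0; i < ndims; i++)
        printf("strides[%lld] = %lld\n", (long long)i, (long long)strides[i]);
    return 0;
}
)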
@@ -412,9 +417,11 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
   size_t batch_sizes[nbatches];
   size_t batch_offsets[nbatches];
+  size_t total_batch_size = 0;
   if (array_len(mctxs[0]->inputs) > 0) {
     for (size_t b = 0; b < nbatches; ++b) {
       batch_sizes[b] = RAI_TensorDim(mctxs[b]->inputs[0].tensor, 0);
+      total_batch_size += batch_sizes[b];
     }
     batch_offsets[0] = 0;
     for (size_t b = 1; b < nbatches; ++b) {
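(Annotation: `total_batch_size` accumulates the first dimension of each run context's input, and the offsets loop this hunk ends on presumably continues as a prefix sum, `batch_offsets[b] = batch_offsets[b-1] + batch_sizes[b-1]`; that loop body falls outside the hunk. A small self-contained sketch of the bookkeeping, under that assumption:

#include <stdio.h>

int main(void) {
    size_t batch_sizes[] = {2, 3, 1};   /* per-context batch sizes */
    size_t nbatches = 3;
    size_t batch_offsets[3];
    size_t total_batch_size = 0;
    for (size_t b = 0; b < nbatches; ++b) total_batch_size += batch_sizes[b];
    /* Prefix sum: context b's rows start at batch_offsets[b] in the
       concatenated model output. */
    batch_offsets[0] = 0;
    for (size_t b = 1; b < nbatches; ++b)
        batch_offsets[b] = batch_offsets[b - 1] + batch_sizes[b - 1];
    /* Expect offsets {0, 2, 5} and total 6. */
    for (size_t b = 0; b < nbatches; ++b)
        printf("offset[%zu] = %zu\n", b, batch_offsets[b]);
    printf("total = %zu\n", total_batch_size);
    return 0;
}
)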
@@ -530,14 +537,48 @@ int RAI_ModelRunORT(RAI_ModelRunCtx **mctxs, RAI_Error *error)
   }
 
   for (size_t i = 0; i < n_output_nodes; i++) {
-    for (size_t b = 0; b < nbatches; b++) {
-      RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
+    if (nbatches > 1) {
+      OrtTensorTypeAndShapeInfo* info;
+      status = ort->GetTensorTypeAndShape(outputs[i], &info);
+      if (status != NULL) goto error;
+
+      size_t ndims;
+      status = ort->GetDimensionsCount(info, &ndims);
+      if (status != NULL) goto error;
+
+      int64_t dims[ndims];
+      status = ort->GetDimensions(info, dims, ndims);
+      if (status != NULL) goto error;
+
+      if (dims[0] != total_batch_size) {
+        RAI_SetError(error, RAI_EMODELRUN, "ERR Model did not generate the expected batch size");
+        ort->ReleaseStatus(status);
+        return 1;
+      }
+
+      for (size_t b = 0; b < nbatches; b++) {
+        RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], error);
+        if (error->code != RAI_OK) {
+          ort->ReleaseStatus(status);
+          return 1;
+        }
+        if (output_tensor) {
+          mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+          RAI_TensorFree(output_tensor);
+        }
+        else {
+          printf("ERR: non-tensor output from ONNX models, ignoring (currently unsupported)");
+        }
+      }
+    }
+    else {
+      RAI_Tensor* output_tensor = RAI_TensorCreateFromOrtValue(outputs[i], 0, -1, error);
       if (error->code != RAI_OK) {
         ort->ReleaseStatus(status);
         return 1;
       }
       if (output_tensor) {
-        mctxs[b]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
+        mctxs[0]->outputs[i].tensor = RAI_TensorGetShallowCopy(output_tensor);
         RAI_TensorFree(output_tensor);
       }
       else {
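(Annotation: taken together, the batched branch first validates that the model produced exactly `total_batch_size` rows, then carves the single concatenated output back into per-context tensors via `RAI_TensorCreateFromOrtValue(outputs[i], batch_offsets[b], batch_sizes[b], ...)`; the single-context branch passes the -1 sentinel to take the whole output. An illustrative sketch of what that per-context slicing amounts to for a contiguous row-major buffer; the real function presumably builds an RAI_Tensor from the OrtValue rather than copying into a local array:

#include <stdio.h>
#include <string.h>

int main(void) {
    /* Concatenated output: total_batch_size = 6 rows, 2 values per row. */
    float out[6 * 2];
    for (int i = 0; i < 12; i++) out[i] = (float)i;
    size_t batch_offset = 2, batch_size = 3, row_elems = 2;
    /* Slice out rows [batch_offset, batch_offset + batch_size). */
    float slice[3 * 2];
    memcpy(slice, out + batch_offset * row_elems,
           batch_size * row_elems * sizeof(float));
    /* slice now holds rows 2..4 of the concatenated output: 4 5 6 7 8 9 */
    for (size_t i = 0; i < batch_size * row_elems; i++) printf("%g ", slice[i]);
    printf("\n");
    return 0;
}
)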