@@ -484,8 +484,10 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
484484 private protected override void SaveModel ( ModelSaveContext ctx ) => _parent . SaveModel ( ctx ) ;
485485
486486 protected override Delegate MakeGetter ( DataViewRow input , int iinfo , Func < int , bool > activeOutput , out Action disposer )
487+ => throw new NotImplementedException ( "This should never be called!" ) ;
488+
489+ private Delegate CreateGetter ( DataViewRow input , int iinfo , Func < int , bool > activeOutput , OnnxRuntimeOutputCacher outputCacher )
487490 {
488- disposer = null ;
489491 Host . AssertValue ( input ) ;
490492
491493 var activeOutputColNames = _parent . Outputs . Where ( ( x , i ) => activeOutput ( i ) ) . ToArray ( ) ;
@@ -495,26 +497,59 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
495497 var elemRawType = vectorType . ItemType . RawType ;
496498 var srcNamedValueGetters = GetNamedOnnxValueGetters ( input , _inputColIndices , _inputOnnxTypes , _inputTensorShapes ) ;
497499 if ( vectorType . ItemType is TextDataViewType )
498- return MakeStringTensorGetter ( input , iinfo , srcNamedValueGetters , activeOutputColNames ) ;
500+ return MakeStringTensorGetter ( input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
499501 else
500- return Utils . MarshalInvoke ( MakeTensorGetter < int > , elemRawType , input , iinfo , srcNamedValueGetters , activeOutputColNames ) ;
502+ return Utils . MarshalInvoke ( MakeTensorGetter < int > , elemRawType , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
501503 }
502504 else
503505 {
504506 var type = _parent . Model . ModelInfo . OutputsInfo [ _parent . MapDataViewColumnToOnnxOutputTensor ( iinfo ) ] . DataViewType . RawType ;
505507 var srcNamedValueGetters = GetNamedOnnxValueGetters ( input , _inputColIndices , _inputOnnxTypes , _inputTensorShapes ) ;
506- return Utils . MarshalInvoke ( MakeObjectGetter < int > , type , input , iinfo , srcNamedValueGetters , activeOutputColNames ) ;
508+ return Utils . MarshalInvoke ( MakeObjectGetter < int > , type , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
509+ }
510+ }
511+
512+ public override Delegate [ ] CreateGetters ( DataViewRow input , Func < int , bool > activeOutput , out Action disposer )
513+ {
514+ Contracts . Assert ( input . Schema == InputSchema ) ;
515+
516+ OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher ( ) ;
517+
518+ int n = OutputColumns . Value . Length ;
519+ var result = new Delegate [ n ] ;
520+ for ( int i = 0 ; i < n ; i ++ )
521+ {
522+ if ( ! activeOutput ( i ) )
523+ continue ;
524+ result [ i ] = CreateGetter ( input , i , activeOutput , outputCacher ) ;
507525 }
526+ disposer = ( ) =>
527+ {
528+ outputCacher . Dispose ( ) ;
529+ } ;
530+ return result ;
508531 }
509532
510- private class OnnxRuntimeOutputCacher
533+ private sealed class OnnxRuntimeOutputCacher : IDisposable
511534 {
512535 public long Position ;
513- public Dictionary < string , NamedOnnxValue > Outputs ;
536+ public Dictionary < string , DisposableNamedOnnxValue > Outputs ;
537+ public IDisposableReadOnlyCollection < DisposableNamedOnnxValue > OutputOnnxValues ;
538+
514539 public OnnxRuntimeOutputCacher ( )
515540 {
516541 Position = - 1 ;
517- Outputs = new Dictionary < string , NamedOnnxValue > ( ) ;
542+ Outputs = new Dictionary < string , DisposableNamedOnnxValue > ( ) ;
543+ }
544+
545+ private bool _isDisposed ;
546+
547+ public void Dispose ( )
548+ {
549+ if ( _isDisposed )
550+ return ;
551+ OutputOnnxValues ? . Dispose ( ) ;
552+ _isDisposed = true ;
518553 }
519554 }
520555
@@ -529,21 +564,22 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed
529564 inputNameOnnxValues . Add ( srcNamedOnnxValueGetters [ i ] . GetNamedOnnxValue ( ) ) ;
530565 }
531566
532- var outputNamedOnnxValues = _parent . Model . Run ( inputNameOnnxValues ) ;
533- Contracts . Assert ( outputNamedOnnxValues . Count > 0 ) ;
567+ outputCache . OutputOnnxValues ? . Dispose ( ) ;
568+ outputCache . OutputOnnxValues = _parent . Model . Run ( inputNameOnnxValues ) ;
569+ Contracts . Assert ( outputCache . OutputOnnxValues . Count > 0 ) ;
534570
535- foreach ( var outputNameOnnxValue in outputNamedOnnxValues )
571+ foreach ( var outputNameOnnxValue in outputCache . OutputOnnxValues )
536572 {
537573 outputCache . Outputs [ outputNameOnnxValue . Name ] = outputNameOnnxValue ;
538574 }
539575 outputCache . Position = position ;
540576 }
541577 }
542578
543- private Delegate MakeTensorGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames )
579+ private Delegate MakeTensorGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters ,
580+ string [ ] activeOutputColNames , OnnxRuntimeOutputCacher outputCacher )
544581 {
545582 Host . AssertValue ( input ) ;
546- var outputCacher = new OnnxRuntimeOutputCacher ( ) ;
547583 ValueGetter < VBuffer < T > > valueGetter = ( ref VBuffer < T > dst ) =>
548584 {
549585 UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
@@ -558,10 +594,11 @@ private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxVal
558594 return valueGetter ;
559595 }
560596
561- private Delegate MakeStringTensorGetter ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames )
597+ private Delegate MakeStringTensorGetter ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters ,
598+ string [ ] activeOutputColNames , OnnxRuntimeOutputCacher outputCacher )
562599 {
563600 Host . AssertValue ( input ) ;
564- var outputCacher = new OnnxRuntimeOutputCacher ( ) ;
601+
565602 ValueGetter < VBuffer < ReadOnlyMemory < char > > > valueGetter = ( ref VBuffer < ReadOnlyMemory < char > > dst ) =>
566603 {
567604 UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
@@ -580,14 +617,15 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx
580617 return valueGetter ;
581618 }
582619
583- private Delegate MakeObjectGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames )
620+ private Delegate MakeObjectGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters ,
621+ string [ ] activeOutputColNames , OnnxRuntimeOutputCacher outputCacher )
584622 {
585623 Host . AssertValue ( input ) ;
586- var outputCache = new OnnxRuntimeOutputCacher ( ) ;
624+
587625 ValueGetter < T > valueGetter = ( ref T dst ) =>
588626 {
589- UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
590- var namedOnnxValue = outputCache . Outputs [ _parent . Outputs [ iinfo ] ] ;
627+ UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
628+ var namedOnnxValue = outputCacher . Outputs [ _parent . Outputs [ iinfo ] ] ;
591629 var trueValue = namedOnnxValue . AsEnumerable < NamedOnnxValue > ( ) . Select ( value => value . AsDictionary < string , float > ( ) ) ;
592630 var caster = _parent . Model . ModelInfo . OutputsInfo [ _parent . MapDataViewColumnToOnnxOutputTensor ( iinfo ) ] . Caster ;
593631 dst = ( T ) caster ( namedOnnxValue ) ;
0 commit comments