@@ -362,8 +362,12 @@ public void Dispose()
362362 _isDisposed = true ;
363363 }
364364
365- private sealed class Mapper : MapperBase
365+ private sealed class Mapper : IRowMapper
366366 {
367+ private readonly IHost _host ;
368+ private readonly DataViewSchema _inputSchema ;
369+ private readonly Lazy < DataViewSchema . DetachedColumn [ ] > _outputColumns ;
370+
367371 private readonly OnnxTransformer _parent ;
368372 /// <summary>
369373 /// <see cref="_inputColIndices"/>'s i-th element value tells the <see cref="IDataView"/> column index to
@@ -379,9 +383,11 @@ private sealed class Mapper : MapperBase
379383 /// </summary>
380384 private readonly Type [ ] _inputOnnxTypes ;
381385
382- public Mapper ( OnnxTransformer parent , DataViewSchema inputSchema ) :
383- base ( Contracts . CheckRef ( parent , nameof ( parent ) ) . Host . Register ( nameof ( Mapper ) ) , inputSchema , parent )
386+ public Mapper ( OnnxTransformer parent , DataViewSchema inputSchema )
384387 {
388+ _host = Contracts . CheckRef ( parent , nameof ( parent ) ) . Host . Register ( nameof ( Mapper ) ) ;
389+ _inputSchema = inputSchema ;
390+ _outputColumns = new Lazy < DataViewSchema . DetachedColumn [ ] > ( GetOutputColumnsCore ) ;
385391
386392 _parent = parent ;
387393 _inputColIndices = new int [ _parent . Inputs . Length ] ;
@@ -401,15 +407,15 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
401407
402408 var col = inputSchema . GetColumnOrNull ( _parent . Inputs [ i ] ) ;
403409 if ( ! col . HasValue )
404- throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , _parent . Inputs [ i ] ) ;
410+ throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , _parent . Inputs [ i ] ) ;
405411
406412 _inputColIndices [ i ] = col . Value . Index ;
407413
408414 var type = inputSchema [ _inputColIndices [ i ] ] . Type ;
409415 var vectorType = type as VectorDataViewType ;
410416
411417 if ( vectorType != null && vectorType . Size == 0 )
412- throw Host . Except ( $ "Variable length input columns not supported") ;
418+ throw _host . Except ( $ "Variable length input columns not supported") ;
413419
414420 var itemType = type . GetItemType ( ) ;
415421 var nodeItemType = inputNodeInfo . DataViewType . GetItemType ( ) ;
@@ -421,7 +427,7 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
421427 // This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426
422428 var isKeyType = itemType is KeyDataViewType ;
423429 if ( ! isKeyType || itemType . RawType != nodeItemType . RawType )
424- throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , _parent . Inputs [ i ] , inputNodeInfo . DataViewType . GetItemType ( ) . ToString ( ) , type . ToString ( ) ) ;
430+ throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , _parent . Inputs [ i ] , inputNodeInfo . DataViewType . GetItemType ( ) . ToString ( ) , type . ToString ( ) ) ;
425431 }
426432
427433 // If the column is one dimension we make sure that the total size of the Onnx shape matches.
@@ -433,8 +439,9 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
433439 throw Contracts . Except ( $ "Input shape mismatch: Input '{ _parent . Inputs [ i ] } ' has shape { String . Join ( "," , inputShape ) } , but input data is of length { typeValueCount } .") ;
434440 }
435441 }
/// <summary>
/// Explicit IRowMapper.GetOutputColumns implementation added by this diff.
/// Returns the value of the lazy _outputColumns field, which the constructor wires to GetOutputColumnsCore,
/// so the column array is built on first access rather than at construction time.
/// </summary>
442+ DataViewSchema . DetachedColumn [ ] IRowMapper . GetOutputColumns ( ) => _outputColumns . Value ;
436443
437- protected override DataViewSchema . DetachedColumn [ ] GetOutputColumnsCore ( )
444+ private DataViewSchema . DetachedColumn [ ] GetOutputColumnsCore ( )
438445 {
439446 var stdSuffix = ".output" ;
440447 var info = new DataViewSchema . DetachedColumn [ _parent . Outputs . Length ] ;
@@ -476,17 +483,16 @@ private void AddSlotNames(string columnName, DataViewSchema.Annotations.Builder
476483 builder . AddSlotNames ( count , getter ) ;
477484 }
478485
/// <summary>
/// Builds the predicate that tells which input columns are needed for the requested outputs.
/// This diff changes the modifier from private protected override to private; the body is unchanged.
/// </summary>
/// <param name="activeOutput">Predicate over output column indices; true means the output is requested.</param>
/// <returns>
/// A predicate over input column indices. It is true when at least one index in the range
/// 0 .. _parent.Outputs.Length - 1 is active AND the column index appears in _inputColIndices.
/// </returns>
479- private protected override Func < int , bool > GetDependenciesCore ( Func < int , bool > activeOutput )
486+ private Func < int , bool > GetDependenciesCore ( Func < int , bool > activeOutput )
480487 {
// Both scans are inside the returned lambda, so they run each time the predicate is invoked.
481488 return col => Enumerable . Range ( 0 , _parent . Outputs . Length ) . Any ( i => activeOutput ( i ) ) && _inputColIndices . Any ( i => i == col ) ;
482489 }
483490
/// <summary>
/// Forwards model saving to the owning transformer by calling _parent.SaveModel with the same context.
/// This diff changes the modifier from private protected override to private.
/// </summary>
/// <param name="ctx">Save context passed through unchanged to the parent.</param>
484- private protected override void SaveModel ( ModelSaveContext ctx ) => _parent . SaveModel ( ctx ) ;
491+ private void SaveModel ( ModelSaveContext ctx ) => _parent . SaveModel ( ctx ) ;
485492
486- protected override Delegate MakeGetter ( DataViewRow input , int iinfo , Func < int , bool > activeOutput , out Action disposer )
493+ private Delegate MakeGetter ( DataViewRow input , int iinfo , Func < int , bool > activeOutput , OnnxRuntimeOutputCacher outputCacher )
487494 {
488- disposer = null ;
489- Host . AssertValue ( input ) ;
495+ _host . AssertValue ( input ) ;
490496
491497 var activeOutputColNames = _parent . Outputs . Where ( ( x , i ) => activeOutput ( i ) ) . ToArray ( ) ;
492498
@@ -495,26 +501,65 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
495501 var elemRawType = vectorType . ItemType . RawType ;
496502 var srcNamedValueGetters = GetNamedOnnxValueGetters ( input , _inputColIndices , _inputOnnxTypes , _inputTensorShapes ) ;
497503 if ( vectorType . ItemType is TextDataViewType )
498- return MakeStringTensorGetter ( input , iinfo , srcNamedValueGetters , activeOutputColNames ) ;
504+ return MakeStringTensorGetter ( input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
499505 else
500- return Utils . MarshalInvoke ( MakeTensorGetter < int > , elemRawType , input , iinfo , srcNamedValueGetters , activeOutputColNames ) ;
506+ return Utils . MarshalInvoke ( MakeTensorGetter < int > , elemRawType , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
501507 }
502508 else
503509 {
504510 var type = _parent . Model . ModelInfo . OutputsInfo [ _parent . MapDataViewColumnToOnnxOutputTensor ( iinfo ) ] . DataViewType . RawType ;
505511 var srcNamedValueGetters = GetNamedOnnxValueGetters ( input , _inputColIndices , _inputOnnxTypes , _inputTensorShapes ) ;
506- return Utils . MarshalInvoke ( MakeObjectGetter < int > , type , input , iinfo , srcNamedValueGetters , activeOutputColNames ) ;
512+ return Utils . MarshalInvoke ( MakeObjectGetter < int > , type , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
513+ }
514+ }
515+
/// <summary>
/// Explicit IRowMapper.CreateGetters implementation added by this diff.
/// Creates one getter delegate per active output column. All getters created by a single call
/// share one OnnxRuntimeOutputCacher instance.
/// </summary>
/// <param name="input">Row to read from; its Schema is asserted to equal the schema given to the constructor.</param>
/// <param name="activeOutput">Predicate over output column indices; true means a getter is wanted.</param>
/// <param name="disposer">Set to an action that disposes the shared output cacher.</param>
/// <returns>
/// An array with one slot per output column. Slots of inactive columns are never assigned and stay null.
/// </returns>
516+ Delegate [ ] IRowMapper . CreateGetters ( DataViewRow input , Func < int , bool > activeOutput , out Action disposer )
517+ {
518+ Contracts . Assert ( input . Schema == _inputSchema ) ;
519+
// One cacher for the whole call: every getter built below receives this same instance.
520+ OnnxRuntimeOutputCacher outputCacher = new OnnxRuntimeOutputCacher ( ) ;
521+
// Reading _outputColumns.Value forces the lazy output column array if it was not built yet.
522+ int n = _outputColumns . Value . Length ;
523+ var result = new Delegate [ n ] ;
524+ for ( int i = 0 ; i < n ; i ++ )
525+ {
// Skip inactive outputs; their slot in result keeps the array default (null).
526+ if ( ! activeOutput ( i ) )
527+ continue ;
528+ result [ i ] = MakeGetter ( input , i , activeOutput , outputCacher ) ;
507529 }
// The caller receives the cleanup action through the out parameter; it disposes the shared cacher.
530+ disposer = ( ) =>
531+ {
532+ outputCacher . Dispose ( ) ;
533+ } ;
534+ return result ;
508535 }
509536
/// <summary>
/// Holds the ONNX outputs stored for one row position so getters can look them up by output name.
/// This diff makes the class internal instead of private, makes it implement IDisposable, and switches
/// the stored value type from NamedOnnxValue to DisposableNamedOnnxValue.
/// </summary>
510- private class OnnxRuntimeOutputCacher
537+ internal class OnnxRuntimeOutputCacher : IDisposable
511538 {
// Row position the cached outputs belong to. The constructor sets it to -1.
512539 public long Position ;
// Cached outputs keyed by name (old declaration removed, new declaration added by this diff).
513- public Dictionary < string , NamedOnnxValue > Outputs ;
540+ public Dictionary < string , DisposableNamedOnnxValue > Outputs ;
// Disposable collection of output values; this is the only member released by Dispose below.
541+ public IDisposableReadOnlyCollection < DisposableNamedOnnxValue > OutputOnnxValues ;
542+
// Starts with Position = -1 and an empty Outputs dictionary. OutputOnnxValues is left unassigned (null).
514543 public OnnxRuntimeOutputCacher ( )
515544 {
516545 Position = - 1 ;
517- Outputs = new Dictionary < string , NamedOnnxValue > ( ) ;
546+ Outputs = new Dictionary < string , DisposableNamedOnnxValue > ( ) ;
547+ }
548+
// Guards Dispose(bool) so the release logic runs at most once.
549+ private bool _isDisposed ;
550+
// Releases OutputOnnxValues when it is non-null, then marks the instance disposed.
// The disposing argument is not read in this body.
// NOTE(review): the Outputs dictionary is not cleared here, so it may still reference values that
// belong to the disposed collection - confirm nothing reads Outputs after Dispose.
551+ protected virtual void Dispose ( bool disposing )
552+ {
553+ if ( _isDisposed )
554+ return ;
555+ OutputOnnxValues ? . Dispose ( ) ;
556+ _isDisposed = true ;
557+ }
558+
// Public entry point: runs Dispose(true) and then calls GC.SuppressFinalize on this instance.
559+ public void Dispose ( )
560+ {
561+ Dispose ( disposing : true ) ;
562+ GC . SuppressFinalize ( this ) ;
518563 }
519564 }
520565
@@ -529,46 +574,47 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed
529574 inputNameOnnxValues . Add ( srcNamedOnnxValueGetters [ i ] . GetNamedOnnxValue ( ) ) ;
530575 }
531576
532- var outputNamedOnnxValues = _parent . Model . Run ( inputNameOnnxValues ) ;
533- Contracts . Assert ( outputNamedOnnxValues . Count > 0 ) ;
577+ outputCache . OutputOnnxValues ? . Dispose ( ) ;
578+ outputCache . OutputOnnxValues = _parent . Model . Run ( inputNameOnnxValues ) ;
579+ Contracts . Assert ( outputCache . OutputOnnxValues . Count > 0 ) ;
534580
535- foreach ( var outputNameOnnxValue in outputNamedOnnxValues )
581+ foreach ( var outputNameOnnxValue in outputCache . OutputOnnxValues )
536582 {
537583 outputCache . Outputs [ outputNameOnnxValue . Name ] = outputNameOnnxValue ;
538584 }
539585 outputCache . Position = position ;
540586 }
541587 }
542588
/// <summary>
/// Builds a getter that fills a VBuffer of T from the dense tensor of output column iinfo.
/// This diff removes the per-getter cacher that was created inside this method and instead takes the
/// cacher as a parameter; it also replaces Host with _host.
/// </summary>
/// <param name="input">Row whose Position is passed to UpdateCacheIfNeeded on every getter call.</param>
/// <param name="iinfo">Index into _parent.Outputs selecting the output name to read.</param>
/// <param name="srcNamedValueGetters">Input value getters forwarded to UpdateCacheIfNeeded.</param>
/// <param name="activeOutputColNames">Active output names forwarded to UpdateCacheIfNeeded.</param>
/// <param name="outputCacher">Cache the getter refreshes and then reads the named output from.</param>
/// <returns>A ValueGetter over VBuffer of T, returned as Delegate.</returns>
// The getter throws via the host when the cached value is not a DenseTensor of T.
543- private Delegate MakeTensorGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames )
589+ private Delegate MakeTensorGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters ,
590+ string [ ] activeOutputColNames , OnnxRuntimeOutputCacher outputCacher )
544591 {
545- Host . AssertValue ( input ) ;
546- var outputCacher = new OnnxRuntimeOutputCacher ( ) ;
592+ _host . AssertValue ( input ) ;
547593 ValueGetter < VBuffer < T > > valueGetter = ( ref VBuffer < T > dst ) =>
548594 {
// Refresh the cache for the current row position, then fetch this column by its output name.
549595 UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
550596 var namedOnnxValue = outputCacher . Outputs [ _parent . Outputs [ iinfo ] ] ;
// The as-cast yields null when the value is not a DenseTensor of T; that case throws just below.
551597 var tensor = namedOnnxValue . AsTensor < T > ( ) as Microsoft . ML . OnnxRuntime . Tensors . DenseTensor < T > ;
552598 if ( tensor == null )
553- throw Host . Except ( $ "Output column { namedOnnxValue . Name } doesn't contain a DenseTensor of expected type { typeof ( T ) } ") ;
599+ throw _host . Except ( $ "Output column { namedOnnxValue . Name } doesn't contain a DenseTensor of expected type { typeof ( T ) } ") ;
// Size dst to the tensor length (cast to int) and copy the tensor buffer into it.
554600 var editor = VBufferEditor . Create ( ref dst , ( int ) tensor . Length ) ;
555601 tensor . Buffer . Span . CopyTo ( editor . Values ) ;
556602 dst = editor . Commit ( ) ;
557603 } ;
558604 return valueGetter ;
559605 }
560606
561- private Delegate MakeStringTensorGetter ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames )
607+ private Delegate MakeStringTensorGetter ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters ,
608+ string [ ] activeOutputColNames , OnnxRuntimeOutputCacher outputCacher )
562609 {
563- Host . AssertValue ( input ) ;
564- var outputCacher = new OnnxRuntimeOutputCacher ( ) ;
610+ _host . AssertValue ( input ) ;
565611 ValueGetter < VBuffer < ReadOnlyMemory < char > > > valueGetter = ( ref VBuffer < ReadOnlyMemory < char > > dst ) =>
566612 {
567613 UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
568614 var namedOnnxValue = outputCacher . Outputs [ _parent . Outputs [ iinfo ] ] ;
569615 var tensor = namedOnnxValue . AsTensor < string > ( ) as Microsoft . ML . OnnxRuntime . Tensors . DenseTensor < string > ;
570616 if ( tensor == null )
571- throw Host . Except ( $ "Output column { namedOnnxValue . Name } doesn't contain a DenseTensor of expected type { typeof ( string ) } ") ;
617+ throw _host . Except ( $ "Output column { namedOnnxValue . Name } doesn't contain a DenseTensor of expected type { typeof ( string ) } ") ;
572618
573619 // Create VBufferEditor to fill "dst" with the values in "denseTensor".
574620 var editor = VBufferEditor . Create ( ref dst , ( int ) tensor . Length ) ;
@@ -580,14 +626,14 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx
580626 return valueGetter ;
581627 }
582628
583- private Delegate MakeObjectGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames )
629+ private Delegate MakeObjectGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters ,
630+ string [ ] activeOutputColNames , OnnxRuntimeOutputCacher outputCacher )
584631 {
585- Host . AssertValue ( input ) ;
586- var outputCache = new OnnxRuntimeOutputCacher ( ) ;
632+ _host . AssertValue ( input ) ;
587633 ValueGetter < T > valueGetter = ( ref T dst ) =>
588634 {
589- UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
590- var namedOnnxValue = outputCache . Outputs [ _parent . Outputs [ iinfo ] ] ;
635+ UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCacher ) ;
636+ var namedOnnxValue = outputCacher . Outputs [ _parent . Outputs [ iinfo ] ] ;
591637 var trueValue = namedOnnxValue . AsEnumerable < NamedOnnxValue > ( ) . Select ( value => value . AsDictionary < string , float > ( ) ) ;
592638 var caster = _parent . Model . ModelInfo . OutputsInfo [ _parent . MapDataViewColumnToOnnxOutputTensor ( iinfo ) ] . Caster ;
593639 dst = ( T ) caster ( namedOnnxValue ) ;
@@ -664,6 +710,12 @@ private static INamedOnnxValueGetter CreateNamedOnnxValueGetterVecCore<T>(DataVi
664710 return new NamedOnnxValueGetterVec < T > ( input , colIndex , onnxShape ) ;
665711 }
666712
/// <summary>
/// Explicit ICanSaveModel.Save implementation added by this diff; forwards to the private SaveModel method.
/// </summary>
713+ void ICanSaveModel . Save ( ModelSaveContext ctx ) => SaveModel ( ctx ) ;
714+
/// <summary>
/// Explicit IRowMapper.GetDependencies implementation added by this diff; forwards to GetDependenciesCore
/// with the same activeOutput predicate and returns its result unchanged.
/// </summary>
715+ Func < int , bool > IRowMapper . GetDependencies ( Func < int , bool > activeOutput ) => GetDependenciesCore ( activeOutput ) ;
716+
/// <summary>
/// Returns the OnnxTransformer this mapper was constructed with (the _parent field), typed as ITransformer.
/// </summary>
717+ public ITransformer GetTransformer ( ) => _parent ;
718+
667719 /// <summary>
668720 /// Common function for wrapping ML.NET getter as a NamedOnnxValue getter.
669721 /// </summary>
0 commit comments