@@ -92,9 +92,17 @@ internal sealed class Options : TransformInputBase
9292 internal const string ShortName = "Onnx" ;
9393 internal const string LoaderSignature = "OnnxTransform" ;
9494
95- internal readonly string [ ] Inputs ;
96- internal readonly string [ ] Outputs ;
97- internal readonly DataViewType [ ] OutputTypes ;
95+ /// <summary>
96+ /// Input column names from ML.NET's perspective. It can be ordered differently than ONNX model's input list.
97+ /// It's also possible that the <see cref="Inputs"/> contains less variables than ONNX model's input list.
98+ /// </summary>
99+ internal string [ ] Inputs { get ; }
100+ /// <summary>
101+ /// Output column names from ML.NET's perspective. It can be ordered differently than ONNX model's output list.
102+ /// It's also possible that the <see cref="Outputs"/> contains less variables than ONNX model's output list.
103+ /// </summary>
104+ internal string [ ] Outputs { get ; }
105+ internal DataViewType [ ] OutputTypes { get ; }
98106
99107 private static VersionInfo GetVersionInfo ( )
100108 {
@@ -196,7 +204,7 @@ private OnnxTransformer(IHostEnvironment env, Options options, byte[] modelBytes
196204 var shape = outputNodeInfo . Shape ;
197205 var dims = AdjustDimensions ( shape ) ;
198206 // OutputTypes[i] = new VectorDataViewType(OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type), dims.ToArray());
199- OutputTypes [ i ] = Model . OutputTypes [ i ] ;
207+ OutputTypes [ i ] = Model . ModelInfo . OutputsInfo [ idx ] . MlnetType ;
200208 }
201209 _options = options ;
202210 }
@@ -302,9 +310,22 @@ private static IEnumerable<int> AdjustDimensions(OnnxShape shape)
302310 private sealed class Mapper : MapperBase
303311 {
304312 private readonly OnnxTransformer _parent ;
313+ /// <summary>
314+ /// <see cref="_inputColIndices"/>'s i-th element value tells the <see cref="IDataView"/> column index to
315+ /// find the i-th ONNX input.
316+ /// </summary>
305317 private readonly int [ ] _inputColIndices ;
318+ /// <summary>
319+ /// <see cref="_isInputVector"/>'s i-th element value tells if the i-th ONNX input is a tensor.
320+ /// </summary>
306321 private readonly bool [ ] _isInputVector ;
322+ /// <summary>
323+ /// <see cref="_inputTensorShapes"/>'s i-th element value tells if the i-th ONNX input's shape if it's a tensor.
324+ /// </summary>
307325 private readonly OnnxShape [ ] _inputTensorShapes ;
326+ /// <summary>
327+ /// <see cref="_inputOnnxTypes"/>'s i-th element value tells if the <see cref="Type"/> of the i-th ONNX input.
328+ /// </summary>
308329 private readonly System . Type [ ] _inputOnnxTypes ;
309330
310331 public Mapper ( OnnxTransformer parent , DataViewSchema inputSchema ) :
@@ -327,11 +348,11 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
327348 var inputNodeInfo = model . ModelInfo . InputsInfo [ idx ] ;
328349
329350 var shape = inputNodeInfo . Shape ;
330- var inputType = OnnxUtils . OnnxToMlNetType ( inputNodeInfo . Type ) ;
351+ var inputType = OnnxUtils . OnnxToMlNetType ( inputNodeInfo . OrtType ) ;
331352
332353 var inputShape = AdjustDimensions ( inputNodeInfo . Shape ) ;
333354 _inputTensorShapes [ i ] = inputShape . ToList ( ) ;
334- _inputOnnxTypes [ i ] = inputNodeInfo . Type ;
355+ _inputOnnxTypes [ i ] = inputNodeInfo . OrtType ;
335356
336357 var col = inputSchema . GetColumnOrNull ( _parent . Inputs [ i ] ) ;
337358 if ( ! col . HasValue )
@@ -417,22 +438,21 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
417438 {
418439 disposer = null ;
419440 Host . AssertValue ( input ) ;
420- //Host.Assert(typeof(T) == _outputItemRawType);
421441
422442 var outputCache = new OutputCache ( ) ;
423443 var activeOutputColNames = _parent . Outputs . Where ( ( x , i ) => activeOutput ( i ) ) . ToArray ( ) ;
424444
425- if ( _parent . Model . OutputTypes [ iinfo ] is VectorDataViewType )
445+ if ( _parent . Model . ModelInfo . OutputsInfo [ iinfo ] . MlnetType is VectorDataViewType vectorType )
426446 {
427447 //var type = _parent.OutputTypes[iinfo].RawType;
428- var type = OnnxUtils . OnnxToMlNetType ( _parent . Model . ModelInfo . OutputsInfo [ iinfo ] . Type ) . RawType ;
448+ var elemRawType = vectorType . ItemType . RawType ;
429449 //Host.Assert(type == _parent.OutputTypes[iinfo].GetItemType().RawType);
430450 var srcNamedValueGetters = GetNamedOnnxValueGetters ( input , _parent . Inputs , _inputColIndices , _isInputVector , _inputOnnxTypes , _inputTensorShapes ) ;
431- return Utils . MarshalInvoke ( MakeTensorGetter < int > , type , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
451+ return Utils . MarshalInvoke ( MakeTensorGetter < int > , elemRawType , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
432452 }
433453 else
434454 {
435- var type = _parent . Model . OutputTypes [ iinfo ] . RawType ;
455+ var type = _parent . Model . ModelInfo . OutputsInfo [ iinfo ] . MlnetType . RawType ;
436456 var srcNamedValueGetters = GetNamedOnnxValueGetters ( input , _parent . Inputs , _inputColIndices , _isInputVector , _inputOnnxTypes , _inputTensorShapes ) ;
437457 return Utils . MarshalInvoke ( MakeObjectGetter < int > , type , input , iinfo , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
438458 }
@@ -441,7 +461,7 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b
441461 private Delegate MakeTensorGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames , OutputCache outputCache )
442462 {
443463 Host . AssertValue ( input ) ;
444- ValueGetter < VBuffer < T > > valuegetter = ( ref VBuffer < T > dst ) =>
464+ ValueGetter < VBuffer < T > > valueGetter = ( ref VBuffer < T > dst ) =>
445465 {
446466 UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
447467 var namedOnnxValue = outputCache . Outputs [ _parent . Outputs [ iinfo ] ] ;
@@ -452,20 +472,20 @@ private Delegate MakeTensorGetter<T>(DataViewRow input, int iinfo, INamedOnnxVal
452472 denseTensor . Buffer . Span . CopyTo ( editor . Values ) ;
453473 dst = editor . Commit ( ) ;
454474 } ;
455- return valuegetter ;
475+ return valueGetter ;
456476 }
457477
458478 private Delegate MakeObjectGetter < T > ( DataViewRow input , int iinfo , INamedOnnxValueGetter [ ] srcNamedValueGetters , string [ ] activeOutputColNames , OutputCache outputCache )
459479 {
460480 Host . AssertValue ( input ) ;
461- ValueGetter < T > valuegetter = ( ref T dst ) =>
481+ ValueGetter < T > valueGetter = ( ref T dst ) =>
462482 {
463483 UpdateCacheIfNeeded ( input . Position , srcNamedValueGetters , activeOutputColNames , outputCache ) ;
464484 var namedOnnxValue = outputCache . Outputs [ _parent . Outputs [ iinfo ] ] ;
465485 var trueValue = namedOnnxValue . AsEnumerable < NamedOnnxValue > ( ) . Select ( value => value . AsDictionary < string , float > ( ) ) ;
466486 dst = ( T ) trueValue ;
467487 } ;
468- return valuegetter ;
488+ return valueGetter ;
469489 }
470490
471491 private static INamedOnnxValueGetter [ ] GetNamedOnnxValueGetters ( DataViewRow input ,
@@ -634,7 +654,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
634654 throw Host . Except ( $ "Column { input } doesn't match input node names of model.") ;
635655
636656 var inputNodeInfo = inputsInfo [ idx ] ;
637- var expectedType = OnnxUtils . OnnxToMlNetType ( inputNodeInfo . Type ) ;
657+ var expectedType = OnnxUtils . OnnxToMlNetType ( inputNodeInfo . OrtType ) ;
638658 if ( col . ItemType != expectedType )
639659 throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input , expectedType . ToString ( ) , col . ItemType . ToString ( ) ) ;
640660 }
0 commit comments