@@ -385,8 +385,13 @@ private static IEnumerable<int> AdjustDimensions(OnnxShape shape)
385385 {
386386 if ( shape . Count > 0 )
387387 {
388- return shape . Select ( x => ( x <= 0 ) ? 1 : x ) ;
388+ if ( shape [ 0 ] < 0 )
389+ {
390+ shape [ 0 ] = 1 ;
391+ }
392+ return shape . Select ( x => ( x <= 0 ) ? 0 : x ) ;
389393 }
394+
390395 return new [ ] { 1 } ;
391396 }
392397
@@ -444,6 +449,11 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
444449 var shape = inputNodeInfo . Shape ;
445450
446451 var inputShape = AdjustDimensions ( inputNodeInfo . Shape ) ;
452+
453+ // Only allow a single unkown size dimension
454+ if ( inputShape . Where ( x => x == 0 ) . Count ( ) > 1 )
455+ throw new ArgumentOutOfRangeException ( _parent . Inputs [ i ] , "Only 1 unknown dimension is allowed" ) ;
456+
447457 _inputTensorShapes [ i ] = inputShape . ToList ( ) ;
448458 _inputOnnxTypes [ i ] = inputNodeInfo . TypeInOnnxRuntime ;
449459
@@ -456,9 +466,6 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
456466 var type = inputSchema [ _inputColIndices [ i ] ] . Type ;
457467 var vectorType = type as VectorDataViewType ;
458468
459- if ( vectorType != null && vectorType . Size == 0 )
460- throw Host . Except ( $ "Variable length input columns not supported") ;
461-
462469 var itemType = type . GetItemType ( ) ;
463470 var nodeItemType = inputNodeInfo . DataViewType . GetItemType ( ) ;
464471 if ( itemType != nodeItemType )
@@ -474,11 +481,14 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
474481
475482 // If the column is one dimension we make sure that the total size of the Onnx shape matches.
476483 // Compute the total size of the known dimensions of the shape.
477- int valCount = inputShape . Where ( x => x > 0 ) . Aggregate ( ( x , y ) => x * y ) ;
478- // The column length should be divisible by this, so that the other dimensions can be integral.
479- int typeValueCount = type . GetValueCount ( ) ;
480- if ( typeValueCount % valCount != 0 )
481- throw Contracts . Except ( $ "Input shape mismatch: Input '{ _parent . Inputs [ i ] } ' has shape { String . Join ( "," , inputShape ) } , but input data is of length { typeValueCount } .") ;
484+ if ( ! inputShape . Any ( x => x == 0 ) )
485+ {
486+ int valCount = inputShape . Where ( x => x > 0 ) . Aggregate ( ( x , y ) => x * y ) ;
487+ // The column length should be divisible by this, so that the other dimensions can be integral.
488+ int typeValueCount = type . GetValueCount ( ) ;
489+ if ( typeValueCount % valCount != 0 )
490+ throw Contracts . Except ( $ "Input shape mismatch: Input '{ _parent . Inputs [ i ] } ' has shape { String . Join ( "," , inputShape ) } , but input data is of length { typeValueCount } .") ;
491+ }
482492 }
483493 }
484494
@@ -781,23 +791,56 @@ public NamedOnnxValue GetNamedOnnxValue()
781791
782792 private class NamedOnnxValueGetterVec < T > : INamedOnnxValueGetter
783793 {
794+ private delegate NamedOnnxValue GetNamedOnnxVal ( ) ;
795+
784796 private readonly ValueGetter < VBuffer < T > > _srcGetter ;
785797 private readonly OnnxShape _tensorShape ;
786798 private readonly string _colName ;
787799 private VBuffer < T > _vBuffer ;
788800 private VBuffer < T > _vBufferDense ;
801+ private readonly int _denominator ;
802+ private readonly int _zeroIndex ;
803+ private readonly GetNamedOnnxVal _namedOnnxValueDelegate ;
804+
789805 public NamedOnnxValueGetterVec ( DataViewRow input , int colIndex , OnnxShape tensorShape )
790806 {
791807 _srcGetter = input . GetGetter < VBuffer < T > > ( input . Schema [ colIndex ] ) ;
792808 _tensorShape = tensorShape ;
793809 _colName = input . Schema [ colIndex ] . Name ;
794810 _vBuffer = default ;
795811 _vBufferDense = default ;
812+ _denominator = _tensorShape . Where ( x => x > 0 ) . Aggregate ( ( a , x ) => a * x ) ;
813+ _zeroIndex = _tensorShape . IndexOf ( 0 ) ;
814+
815+ var isKnownSize = ( input . Schema [ colIndex ] . Type as VectorDataViewType ) . IsKnownSize ;
816+
817+ if ( isKnownSize )
818+ _namedOnnxValueDelegate = GetNamedOnnxValueKnownSize ;
819+ else
820+ _namedOnnxValueDelegate = GetNamedOnnxValueUnknownSize ;
796821 }
797822 public NamedOnnxValue GetNamedOnnxValue ( )
823+ {
824+ return _namedOnnxValueDelegate ( ) ;
825+ }
826+
827+ private void GetNamedOnnxValueCore ( )
798828 {
799829 _srcGetter ( ref _vBuffer ) ;
800830 _vBuffer . CopyToDense ( ref _vBufferDense ) ;
831+ }
832+
833+ private NamedOnnxValue GetNamedOnnxValueKnownSize ( )
834+ {
835+ GetNamedOnnxValueCore ( ) ;
836+ return OnnxUtils . CreateNamedOnnxValue ( _colName , _vBufferDense . GetValues ( ) , _tensorShape ) ;
837+ }
838+
839+ private NamedOnnxValue GetNamedOnnxValueUnknownSize ( )
840+ {
841+ GetNamedOnnxValueCore ( ) ;
842+
843+ _tensorShape [ _zeroIndex ] = _vBufferDense . Length / _denominator ;
801844 return OnnxUtils . CreateNamedOnnxValue ( _colName , _vBufferDense . GetValues ( ) , _tensorShape ) ;
802845 }
803846 }
@@ -908,14 +951,14 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
908951 // Get the i-th IDataView input column's name in the underlying ONNX transformer.
909952 var input = Transformer . Inputs [ i ] ;
910953
954+ // Only allow 1 unknown dimension
955+ if ( Transformer . Model . ModelInfo . InputsInfo [ i ] . Shape . Where ( x => x == 0 ) . Count ( ) > 1 )
956+ throw new ArgumentOutOfRangeException ( input , "Only 1 unknown dimension is allowed" ) ;
957+
911958 // Make sure inputSchema contains the i-th input column.
912959 if ( ! inputSchema . TryFindColumn ( input , out var col ) )
913960 throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input ) ;
914961
915- // Make sure that the input columns in inputSchema are fixed shape tensors.
916- if ( col . Kind == SchemaShape . Column . VectorKind . VariableVector )
917- throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input , "vector" , col . GetTypeString ( ) ) ;
918-
919962 var inputsInfo = Transformer . Model . ModelInfo . InputsInfo ;
920963 var idx = Transformer . Model . ModelInfo . InputNames . IndexOf ( input ) ;
921964 if ( idx < 0 )
0 commit comments