@@ -384,12 +384,11 @@ internal static (TF_DataType[] tfOutputTypes, DataViewType[] outputTypes, (Opera
384384 // If there are other dimension that are unknown the transformer will return a variable length vector.
385385 // This is the work around in absence of reshape transformer.
386386 var idims = shape . dims ;
387- int [ ] dims = shape . ndim > 0 ? idims . Skip ( idims [ 0 ] == - 1 ? 1 : 0 ) . ToArray ( ) : new [ ] { 0 } ;
387+ int [ ] dims = shape . ndim > 0 ? idims . Skip ( idims [ 0 ] == - 1 ? 1 : 0 ) . ToArray ( ) : new int [ 0 ] ;
388388 for ( int j = 0 ; j < dims . Length ; j ++ )
389389 dims [ j ] = dims [ j ] == - 1 ? 0 : dims [ j ] ;
390390 if ( dims == null || dims . Length == 0 )
391391 {
392- dims = new [ ] { 1 } ;
393392 outputTypes [ i ] = Tf2MlNetType ( tfOutputType ) ;
394393 }
395394 else
@@ -503,20 +502,18 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
503502 throw Host . Except ( "Variable length input columns not supported" ) ;
504503
505504 _isInputVector [ i ] = type is VectorDataViewType ;
506- if ( ! _isInputVector [ i ] )
507- throw Host . Except ( "Non-vector columns are not supported and should be loaded as vector columns of size 1" ) ;
508- vecType = ( VectorDataViewType ) type ;
509505 var expectedType = Tf2MlNetType ( _parent . TFInputTypes [ i ] ) ;
510506 if ( type . GetItemType ( ) != expectedType )
511507 throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , _parent . Inputs [ i ] , expectedType . ToString ( ) , type . ToString ( ) ) ;
512508 var originalShape = _parent . TFInputShapes [ i ] ;
513509 var shape = originalShape . dims ;
514510
515- var colTypeDims = vecType . Dimensions . Select ( dim => ( int ) dim ) . ToArray ( ) ;
516511 if ( shape == null || ( shape . Length == 0 ) )
517- _fullySpecifiedShapes [ i ] = new TensorShape ( colTypeDims ) ;
512+ _fullySpecifiedShapes [ i ] = new TensorShape ( ) ;
518513 else
519514 {
515+ vecType = ( VectorDataViewType ) type ;
516+ var colTypeDims = vecType . Dimensions . Select ( dim => ( int ) dim ) . ToArray ( ) ;
520517 // If the column is one dimension we make sure that the total size of the TF shape matches.
521518 // Compute the total size of the known dimensions of the shape.
522519 int valCount = 1 ;
@@ -561,7 +558,10 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
561558
562559 if ( _parent . _addBatchDimensionInput )
563560 {
564- var l = new int [ _fullySpecifiedShapes [ i ] . ndim + 1 ] ;
561+ // ndim of default TensorShape is -1, make originDim to 0 in this case.
562+ // after addBatchDimension, input column will be changed: type -> type[]
563+ var originDim = _fullySpecifiedShapes [ i ] . ndim < 0 ? 0 : _fullySpecifiedShapes [ i ] . ndim ;
564+ var l = new int [ originDim + 1 ] ;
565565 l [ 0 ] = 1 ;
566566 for ( int ishape = 1 ; ishape < l . Length ; ishape ++ )
567567 l [ ishape ] = _fullySpecifiedShapes [ i ] . dims [ ishape - 1 ] ;
@@ -729,11 +729,10 @@ public TensorValueGetter(DataViewRow input, int colIndex, TensorShape tfShape)
729729 {
730730 _srcgetter = input . GetGetter < T > ( input . Schema [ colIndex ] ) ;
731731 _tfShape = tfShape ;
732- long size = 0 ;
732+ long size = 1 ;
733733 _position = 0 ;
734- if ( tfShape . dims . Length != 0 )
734+ if ( tfShape . dims != null && tfShape . dims . Length != 0 )
735735 {
736- size = 1 ;
737736 foreach ( var dim in tfShape . dims )
738737 size *= dim ;
739738 }
@@ -744,8 +743,7 @@ public Tensor GetTensor()
744743 {
745744 var scalar = default ( T ) ;
746745 _srcgetter ( ref scalar ) ;
747- var tensor = new Tensor ( new [ ] { scalar } ) ;
748- tensor . set_shape ( _tfShape ) ;
746+ var tensor = TensorFlowUtils . CastDataAndReturnAsTensor ( scalar ) ;
749747 return tensor ;
750748 }
751749
@@ -928,8 +926,6 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
928926 var input = _options . InputColumns [ i ] ;
929927 if ( ! inputSchema . TryFindColumn ( input , out var col ) )
930928 throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input ) ;
931- if ( ! ( col . Kind == SchemaShape . Column . VectorKind . Vector ) )
932- throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input , "vector" , col . GetTypeString ( ) ) ;
933929 var expectedType = Tf2MlNetType ( _tfInputTypes [ i ] ) ;
934930 if ( col . ItemType != expectedType )
935931 throw _host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , input , expectedType . ToString ( ) , col . ItemType . ToString ( ) ) ;
0 commit comments