@@ -108,21 +108,22 @@ public ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
108108 using ( var ch = env . Register ( "SchemaBindableWrapper" ) . Start ( "Bind" ) )
109109 {
110110 ch . CheckValue ( schema , nameof ( schema ) ) ;
111- ch . CheckParam ( schema . Feature != null , nameof ( schema ) , "Need a features column" ) ;
112- // Ensure that the feature column type is compatible with the needed input type.
113- var type = schema . Feature . Type ;
114- var typeIn = ValueMapper != null ? ValueMapper . InputType : new VectorType ( NumberType . Float ) ;
115- if ( type != typeIn )
111+ if ( schema . Feature != null )
116112 {
117- if ( ! type . ItemType . Equals ( typeIn . ItemType ) )
118- throw ch . Except ( "Incompatible features column type item type: '{0}' vs '{1}'" , type . ItemType , typeIn . ItemType ) ;
119- if ( type . IsVector != typeIn . IsVector )
120- throw ch . Except ( "Incompatible features column type: '{0}' vs '{1}'" , type , typeIn ) ;
121- // typeIn can legally have unknown size.
122- if ( type . VectorSize != typeIn . VectorSize && typeIn . VectorSize > 0 )
123- throw ch . Except ( "Incompatible features column type: '{0}' vs '{1}'" , type , typeIn ) ;
113+ // Ensure that the feature column type is compatible with the needed input type.
114+ var type = schema . Feature . Type ;
115+ var typeIn = ValueMapper != null ? ValueMapper . InputType : new VectorType ( NumberType . Float ) ;
116+ if ( type != typeIn )
117+ {
118+ if ( ! type . ItemType . Equals ( typeIn . ItemType ) )
119+ throw ch . Except ( "Incompatible features column type item type: '{0}' vs '{1}'" , type . ItemType , typeIn . ItemType ) ;
120+ if ( type . IsVector != typeIn . IsVector )
121+ throw ch . Except ( "Incompatible features column type: '{0}' vs '{1}'" , type , typeIn ) ;
122+ // typeIn can legally have unknown size.
123+ if ( type . VectorSize != typeIn . VectorSize && typeIn . VectorSize > 0 )
124+ throw ch . Except ( "Incompatible features column type: '{0}' vs '{1}'" , type , typeIn ) ;
125+ }
124126 }
125-
126127 var mapper = BindCore ( ch , schema ) ;
127128 ch . Done ( ) ;
128129 return mapper ;
@@ -463,15 +464,18 @@ public CalibratedRowMapper(RoleMappedSchema schema, SchemaBindableBinaryPredicto
463464 Contracts . AssertValue ( parent ) ;
464465 Contracts . Assert ( parent . _distMapper != null ) ;
465466 Contracts . AssertValue ( schema ) ;
466- Contracts . AssertValue ( schema . Feature ) ;
467+ Contracts . AssertValueOrNull ( schema . Feature ) ;
467468
468469 _parent = parent ;
469470 _inputSchema = schema ;
470471 _outputSchema = new BinaryClassifierSchema ( ) ;
471472
472- var typeSrc = _inputSchema . Feature . Type ;
473- Contracts . Check ( typeSrc . IsKnownSizeVector && typeSrc . ItemType == NumberType . Float ,
474- "Invalid feature column type" ) ;
473+ if ( schema . Feature != null )
474+ {
475+ var typeSrc = _inputSchema . Feature . Type ;
476+ Contracts . Check ( typeSrc . IsKnownSizeVector && typeSrc . ItemType == NumberType . Float ,
477+ "Invalid feature column type" ) ;
478+ }
475479 }
476480
477481 public RoleMappedSchema InputSchema { get { return _inputSchema ; } }
@@ -484,15 +488,15 @@ public Func<int, bool> GetDependencies(Func<int, bool> predicate)
484488 {
485489 for ( int i = 0 ; i < OutputSchema . ColumnCount ; i ++ )
486490 {
487- if ( predicate ( i ) )
491+ if ( predicate ( i ) && _inputSchema . Feature != null )
488492 return col => col == _inputSchema . Feature . Index ;
489493 }
490494 return col => false ;
491495 }
492496
493497 public IEnumerable < KeyValuePair < RoleMappedSchema . ColumnRole , string > > GetInputColumnRoles ( )
494498 {
495- yield return RoleMappedSchema . ColumnRole . Feature . Bind ( _inputSchema . Feature . Name ) ;
499+ yield return RoleMappedSchema . ColumnRole . Feature . Bind ( _inputSchema . Feature != null ? _inputSchema . Feature . Name : null ) ;
496500 }
497501
498502 private Delegate [ ] CreateGetters ( IRow input , bool [ ] active )
@@ -504,7 +508,7 @@ private Delegate[] CreateGetters(IRow input, bool[] active)
504508 if ( active [ 0 ] || active [ 1 ] )
505509 {
506510 // Put all captured locals at this scope.
507- var featureGetter = input . GetGetter < VBuffer < Float > > ( _inputSchema . Feature . Index ) ;
511+ var featureGetter = _inputSchema . Feature != null ? input . GetGetter < VBuffer < Float > > ( _inputSchema . Feature . Index ) : null ;
508512 Float prob = 0 ;
509513 Float score = 0 ;
510514 long cachedPosition = - 1 ;
@@ -543,7 +547,9 @@ private static void EnsureCachedResultValueMapper(ValueMapper<VBuffer<Float>, Fl
543547 Contracts . AssertValue ( mapper ) ;
544548 if ( cachedPosition != input . Position )
545549 {
546- featureGetter ( ref features ) ;
550+ if ( featureGetter != null )
551+ featureGetter ( ref features ) ;
552+
547553 mapper ( ref features , ref score , ref prob ) ;
548554 cachedPosition = input . Position ;
549555 }
0 commit comments