1717using Microsoft . ML . Runtime . Model ;
1818using Microsoft . ML . Runtime . Numeric ;
1919
20- [ assembly: LoadableClass ( typeof ( IDataScorerTransform ) , typeof ( FeatureContributionCalculatingTransformer . BindableMapper ) , typeof ( FeatureContributionCalculatingTransformer . Arguments ) ,
20+ [ assembly: LoadableClass ( typeof ( IDataScorerTransform ) , typeof ( FeatureContributionCalculatingTransformer . BindableMapper ) , typeof ( FeatureContributionCalculatingTransformer . BindableMapper . Arguments ) ,
2121 typeof ( SignatureDataScorer ) , "Feature Contribution Transform" , "fct" , "FeatureContributionCalculationTransform" , MetadataUtils . Const . ScoreColumnKind . FeatureContribution ) ]
2222
23- [ assembly: LoadableClass ( typeof ( ISchemaBindableMapper ) , typeof ( FeatureContributionCalculatingTransformer . BindableMapper ) , typeof ( FeatureContributionCalculatingTransformer . Arguments ) ,
23+ [ assembly: LoadableClass ( typeof ( ISchemaBindableMapper ) , typeof ( FeatureContributionCalculatingTransformer . BindableMapper ) , typeof ( FeatureContributionCalculatingTransformer . BindableMapper . Arguments ) ,
2424 typeof ( SignatureBindableMapper ) , "Feature Contribution Mapper" , "fct" , MetadataUtils . Const . ScoreColumnKind . FeatureContribution ) ]
2525
26- [ assembly: LoadableClass ( typeof ( ISchemaBindableMapper ) , typeof ( FeatureContributionCalculatingTransformer . BindableMapper ) , null , typeof ( SignatureLoadModel ) ,
27- "Feature Contribution Mapper" , FeatureContributionCalculatingTransformer . MapperLoaderSignature ) ]
26+ [ assembly: LoadableClass ( typeof ( FeatureContributionCalculatingTransformer . BindableMapper ) , typeof ( FeatureContributionCalculatingTransformer . BindableMapper ) , null , typeof ( SignatureLoadModel ) ,
27+ "Feature Contribution Mapper" , FeatureContributionCalculatingTransformer . BindableMapper . MapperLoaderSignature ) ]
2828
2929[ assembly: LoadableClass ( FeatureContributionCalculatingTransformer . Summary , typeof ( FeatureContributionCalculatingTransformer ) , null , typeof ( SignatureLoadModel ) ,
3030 FeatureContributionCalculatingTransformer . FriendlyName , FeatureContributionCalculatingTransformer . LoaderSignature ) ]
@@ -51,38 +51,12 @@ namespace Microsoft.ML.Runtime.Data
5151 /// </example>
5252 public sealed class FeatureContributionCalculatingTransformer : RowToRowTransformerBase
5353 {
54- public sealed class Arguments : ScorerArgumentsBase
55- {
56- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Number of top contributions" , SortOrder = 1 ) ]
57- public int Top = 10 ;
58-
59- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Number of bottom contributions" , SortOrder = 2 ) ]
60- public int Bottom = 10 ;
61-
62- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Whether or not output of Features contribution should be normalized" , ShortName = "norm" , SortOrder = 3 ) ]
63- public bool Normalize = true ;
64-
65- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Whether or not output of Features contribution in string key-value format" , ShortName = "str" , SortOrder = 4 ) ]
66- public bool Stringify = false ;
67-
68- // REVIEW: the scorer currently ignores the 'suffix' argument from the base class. It should respect it.
69- }
70-
7154 // Apparently, loader signature is limited in length to 24 characters.
7255 internal const string Summary = "For each data point, calculates the contribution of individual features to the model prediction." ;
7356 internal const string FriendlyName = "Feature Contribution Transform" ;
7457 internal const string LoaderSignature = "FeatureContribution" ;
7558
76- internal const string MapperLoaderSignature = "WTFBindable" ;
77-
78- private const int MaxTopBottom = 1000 ;
79-
80- private readonly string _features ;
81- private readonly int _topContributionsCount ;
82- private readonly int _bottomContributionsCount ;
83- private readonly bool _normalize ;
84- private readonly bool _stringify ;
85- private readonly IFeatureContributionMapper _predictor ;
59+ private readonly string _featureColumn ;
8660 private readonly BindableMapper _mapper ;
8761
8862 private static VersionInfo GetVersionInfo ( )
@@ -97,23 +71,22 @@ private static VersionInfo GetVersionInfo()
9771 }
9872
9973 // TODO documentation
100- public FeatureContributionCalculatingTransformer ( IHostEnvironment env , IPredictor predictor , string featuresColumn , Arguments args )
74+ public FeatureContributionCalculatingTransformer ( IHostEnvironment env , IPredictor predictor , string featureColumn ,
75+ int top = FeatureContributionCalculatingEstimator . Defaults . Top ,
76+ int bottom = FeatureContributionCalculatingEstimator . Defaults . Bottom ,
77+ bool normalize = FeatureContributionCalculatingEstimator . Defaults . Normalize ,
78+ bool stringigy = FeatureContributionCalculatingEstimator . Defaults . Stringify )
10179 : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( FeatureContributionCalculatingTransformer ) ) )
10280 {
103- Host . CheckValue ( args , nameof ( args ) ) ;
10481 Host . CheckValue ( predictor , nameof ( predictor ) ) ;
82+ Host . CheckNonEmpty ( featureColumn , nameof ( featureColumn ) ) ;
10583
10684 var pred = predictor as IFeatureContributionMapper ;
10785 Host . CheckParam ( pred != null , nameof ( predictor ) , "Predictor doesn't support getting feature contributions" ) ;
10886
10987 // TODO check that the featues column is not empty.
110- _mapper = new BindableMapper ( Host , pred , args . Top , args . Bottom , args . Normalize , args . Stringify ) ;
111- _features = featuresColumn ;
112- _predictor = pred ;
113- _stringify = args . Stringify ;
114- _topContributionsCount = args . Top ;
115- _bottomContributionsCount = args . Bottom ;
116- _normalize = args . Normalize ;
88+ _featureColumn = featureColumn ;
89+ _mapper = new BindableMapper ( Host , pred , top , bottom , normalize , stringigy ) ;
11790 }
11891
11992 // Factory method for SignatureLoadModel
@@ -127,17 +100,8 @@ private FeatureContributionCalculatingTransformer(IHostEnvironment env, ModelLoa
127100 // string features
128101 // BindableMapper mapper
129102
130- // TODO use ctx.LoadModel with BindableMapper instead of this.
131- _features = ctx . LoadNonEmptyString ( ) ;
132- ctx . LoadModel < IFeatureContributionMapper , SignatureLoadModel > ( env , out _predictor , ModelFileUtils . DirPredictor ) ;
133- _topContributionsCount = ctx . Reader . ReadInt32 ( ) ;
134- Contracts . CheckDecode ( 0 < _topContributionsCount && _topContributionsCount <= MaxTopBottom ) ;
135- _bottomContributionsCount = ctx . Reader . ReadInt32 ( ) ;
136- Contracts . CheckDecode ( 0 < _bottomContributionsCount && _bottomContributionsCount <= MaxTopBottom ) ;
137- _normalize = ctx . Reader . ReadBoolByte ( ) ;
138- _stringify = ctx . Reader . ReadBoolByte ( ) ;
139-
140- _mapper = new BindableMapper ( env , _predictor , _topContributionsCount , _bottomContributionsCount , _normalize , _stringify ) ;
103+ _featureColumn = ctx . LoadNonEmptyString ( ) ;
104+ ctx . LoadModel < BindableMapper , SignatureLoadModel > ( env , out _mapper , ModelFileUtils . DirPredictor ) ;
141105 }
142106
143107 // Factory method for SignatureLoadRowMapper.
@@ -153,15 +117,8 @@ public override void Save(ModelSaveContext ctx)
153117 // string features
154118 // BindableMapper mapper
155119
156- ctx . SaveNonEmptyString ( _features ) ;
157- // TODO use ctx.SaveModel with BindableMapper instead of this.
158- ctx . SaveModel ( _predictor , ModelFileUtils . DirPredictor ) ;
159- Contracts . Assert ( 0 < _topContributionsCount && _topContributionsCount <= MaxTopBottom ) ;
160- ctx . Writer . Write ( _topContributionsCount ) ;
161- Contracts . Assert ( 0 < _bottomContributionsCount && _bottomContributionsCount <= MaxTopBottom ) ;
162- ctx . Writer . Write ( _bottomContributionsCount ) ;
163- ctx . Writer . WriteBoolByte ( _normalize ) ;
164- ctx . Writer . WriteBoolByte ( _stringify ) ;
120+ ctx . SaveNonEmptyString ( _featureColumn ) ;
121+ ctx . SaveModel ( _mapper , ModelFileUtils . DirPredictor ) ;
165122 }
166123
167124 private protected override IRowMapper MakeRowMapper ( Schema schema )
@@ -179,12 +136,11 @@ private class Mapper : MapperBase
179136 public Mapper ( FeatureContributionCalculatingTransformer parent , Schema schema )
180137 : base ( parent . Host , schema )
181138 {
182- // TODO some checks? get soem of the columns, initialize some stuff
183139 _parent = parent ;
184140 _bindableMapper = _parent . _mapper ;
185141
186142 var roles = new List < KeyValuePair < RoleMappedSchema . ColumnRole , string > > ( ) ;
187- roles . Add ( new KeyValuePair < RoleMappedSchema . ColumnRole , string > ( RoleMappedSchema . ColumnRole . Feature , _parent . _features ) ) ;
143+ roles . Add ( new KeyValuePair < RoleMappedSchema . ColumnRole , string > ( RoleMappedSchema . ColumnRole . Feature , _parent . _featureColumn ) ) ;
188144 _roleMappedSchema = new RoleMappedSchema ( InputSchema , roles ) ;
189145
190146 var genericMapper = _bindableMapper . GenericMapper . Bind ( Host , _roleMappedSchema ) ;
@@ -204,7 +160,7 @@ public Mapper(FeatureContributionCalculatingTransformer parent, Schema schema)
204160 private protected override Func < int , bool > GetDependenciesCore ( Func < int , bool > activeOutput )
205161 {
206162 var active = new bool [ InputSchema . ColumnCount ] ;
207- InputSchema . TryGetColumnIndex ( _parent . _features , out int featureCol ) ;
163+ InputSchema . TryGetColumnIndex ( _parent . _featureColumn , out int featureCol ) ;
208164 active [ featureCol ] = true ;
209165 return col => active [ col ] ;
210166 }
@@ -219,8 +175,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore()
219175 var result = new List < Schema . DetachedColumn > ( ) ;
220176
221177 // Add Score Column.
222- foreach ( var pair in _outputGenericSchema . GetColumns ( ) )
223- result . Add ( new Schema . DetachedColumn ( pair . column ) ) ;
178+ result . AddRange ( _outputGenericSchema . GetColumns ( ) . Select ( pair => new Schema . DetachedColumn ( pair . column ) ) ) ;
224179
225180 // Add FeatureContributions column.
226181 var builder = new MetadataBuilder ( ) ;
@@ -263,6 +218,23 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func<int, bool> act
263218 // TODO documentation
264219 internal sealed class BindableMapper : ISchemaBindableMapper , ICanSaveModel , IPredictor
265220 {
221+ public sealed class Arguments : ScorerArgumentsBase
222+ {
223+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Number of top contributions" , SortOrder = 1 ) ]
224+ public int Top = FeatureContributionCalculatingEstimator . Defaults . Top ;
225+
226+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Number of bottom contributions" , SortOrder = 2 ) ]
227+ public int Bottom = FeatureContributionCalculatingEstimator . Defaults . Bottom ;
228+
229+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Whether or not output of Features contribution should be normalized" , ShortName = "norm" , SortOrder = 3 ) ]
230+ public bool Normalize = FeatureContributionCalculatingEstimator . Defaults . Normalize ;
231+
232+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Whether or not output of Features contribution in string key-value format" , ShortName = "str" , SortOrder = 4 ) ]
233+ public bool Stringify = FeatureContributionCalculatingEstimator . Defaults . Stringify ;
234+
235+ // REVIEW: the scorer currently ignores the 'suffix' argument from the base class. It should respect it.
236+ }
237+
266238 private readonly int _topContributionsCount ;
267239 private readonly int _bottomContributionsCount ;
268240 private readonly bool _normalize ;
@@ -272,6 +244,9 @@ internal sealed class BindableMapper : ISchemaBindableMapper, ICanSaveModel, IPr
272244 public readonly ISchemaBindableMapper GenericMapper ;
273245 public readonly bool Stringify ;
274246
247+ internal const string MapperLoaderSignature = "WTFBindable" ;
248+ private const int MaxTopBottom = 1000 ;
249+
275250 private static VersionInfo GetVersionInfo ( )
276251 {
277252 return new VersionInfo (
@@ -604,47 +579,64 @@ public void GetMetadata<TValue>(string kind, int col, ref TValue value)
604579 // TODO DOcumentation
605580 public sealed class FeatureContributionCalculatingEstimator : TrivialEstimator < FeatureContributionCalculatingTransformer >
606581 {
607- private readonly FeatureContributionCalculatingTransformer . Arguments _args ;
608- private readonly string _features ;
582+ private readonly string _featureColumn ;
609583 private readonly IPredictor _predictor ;
584+ private readonly bool _stringify ;
585+
586+ public static class Defaults
587+ {
588+ public const int Top = 10 ;
589+ public const int Bottom = 10 ;
590+ public const bool Normalize = true ;
591+ public const bool Stringify = false ;
592+ }
610593
611594 // TODO Documentation
612- public FeatureContributionCalculatingEstimator ( IHostEnvironment env , IPredictor predictor , string featuresColumn , FeatureContributionCalculatingTransformer . Arguments args )
613- : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( FeatureContributionCalculatingTransformer ) ) , new FeatureContributionCalculatingTransformer ( env , predictor , featuresColumn , args ) )
595+ public FeatureContributionCalculatingEstimator ( IHostEnvironment env , IPredictor predictor , string featureColumn ,
596+ int top = Defaults . Top ,
597+ int bottom = Defaults . Bottom ,
598+ bool normalize = Defaults . Normalize ,
599+ bool stringify = Defaults . Stringify )
600+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( FeatureContributionCalculatingTransformer ) ) ,
601+ new FeatureContributionCalculatingTransformer ( env , predictor , featureColumn , top , bottom , normalize , stringify ) )
614602 {
615- // TODO argcheck
616- _args = args ;
617- _features = featuresColumn ;
603+ _featureColumn = featureColumn ;
618604 _predictor = predictor ;
605+ _stringify = stringify ;
619606 }
620607
621608 public override SchemaShape GetOutputSchema ( SchemaShape inputSchema )
622609 {
623610 Host . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
624611 var result = inputSchema . ToDictionary ( x => x . Name ) ;
625612
626- if ( ! inputSchema . TryFindColumn ( _features , out var col ) )
627- throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , _features ) ;
628- var metadata = new List < SchemaShape . Column > ( ) ;
629- if ( col . Metadata . TryFindColumn ( MetadataUtils . Kinds . SlotNames , out var slotMeta ) )
630- metadata . Add ( slotMeta ) ;
631- // TODO: check type of feature column.
632-
633- // TODO: How do we deal with multiclassScoreColumn? should also contain slotnames
634613 // Add Score column.
614+ var scoreMetadata = new List < SchemaShape . Column > ( ) ;
615+ // If multiclass, there could be a SlotNames metadata column, so it is added to the score column metadata in case.
616+ if ( _predictor . PredictionKind == PredictionKind . MultiClassClassification )
617+ scoreMetadata . Add ( new SchemaShape . Column ( MetadataUtils . Kinds . SlotNames , SchemaShape . Column . VectorKind . Vector , TextType . Instance , false ) ) ;
618+ // For some trainers the output could be normalized, but it cannot be known it given the information available here, so it is added in case.
619+ scoreMetadata . AddRange ( MetadataUtils . GetTrainerOutputMetadata ( isNormalized : true ) ) ;
635620 result [ DefaultColumnNames . Score ] = new SchemaShape . Column ( DefaultColumnNames . Score , SchemaShape . Column . VectorKind . Scalar , NumberType . R4 ,
636- false , new SchemaShape ( MetadataUtils . GetTrainerOutputMetadata ( ) ) ) ;
621+ false , new SchemaShape ( scoreMetadata ) ) ;
637622
638623 // Add FeatureContributions column.
639- if ( _args . Stringify )
624+ if ( ! inputSchema . TryFindColumn ( _featureColumn , out var col ) )
625+ throw Host . ExceptSchemaMismatch ( nameof ( inputSchema ) , "input" , _featureColumn ) ;
626+
627+ // TODO: check type of feature column.
628+ if ( _stringify )
640629 {
641630 result [ DefaultColumnNames . FeatureContributions ] = new SchemaShape . Column ( DefaultColumnNames . FeatureContributions , col . Kind ,
642- TextType . Instance , false , new SchemaShape ( metadata . ToArray ( ) ) ) ;
631+ TextType . Instance , false , null ) ;
643632 }
644633 else
645634 {
635+ var featContributionMetadata = new List < SchemaShape . Column > ( ) ;
636+ if ( col . Metadata . TryFindColumn ( MetadataUtils . Kinds . SlotNames , out var slotMeta ) )
637+ featContributionMetadata . Add ( slotMeta ) ;
646638 result [ DefaultColumnNames . FeatureContributions ] = new SchemaShape . Column ( DefaultColumnNames . FeatureContributions , col . Kind ,
647- col . ItemType , false , new SchemaShape ( metadata . ToArray ( ) ) ) ;
639+ col . ItemType , false , new SchemaShape ( featContributionMetadata . ToArray ( ) ) ) ;
648640 }
649641
650642 return new SchemaShape ( result . Values ) ;
0 commit comments