Skip to content

Commit 7fd6709

Browse files
committed
simplified code
1 parent 4355f18 commit 7fd6709

File tree

3 files changed

+85
-105
lines changed

3 files changed

+85
-105
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,6 @@ public static void FeatureContributionCalculationTransform_Regression()
6464
Normalize = false
6565
};
6666

67-
//var featureContributionCalculator = new FeatureContributionCalculatingEstimator(mlContext, args, model.FeatureColumn, model.Model);
68-
//var outputData = featureContributionCalculator.Fit(transformedData).Transform(transformedData);
69-
70-
//var outputData1 = new FeatureContributionCalculationTransform(mlContext, new FeatureContributionCalculationTransform.Arguments() { Top = 11, Normalize = false }, model.FeatureColumn, model.Model).Transform(transformedData);
71-
7267
var featureContributionCalculator = new FeatureContributionCalculatingTransformer(mlContext, model.Model, model.FeatureColumn, args);
7368
var outputData = featureContributionCalculator.Transform(transformedData);
7469

src/Microsoft.ML.Data/Scorers/FeatureContributionCalculatingTransformer.cs

Lines changed: 76 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@
1717
using Microsoft.ML.Runtime.Model;
1818
using Microsoft.ML.Runtime.Numeric;
1919

20-
[assembly: LoadableClass(typeof(IDataScorerTransform), typeof(FeatureContributionCalculatingTransformer.BindableMapper), typeof(FeatureContributionCalculatingTransformer.Arguments),
20+
[assembly: LoadableClass(typeof(IDataScorerTransform), typeof(FeatureContributionCalculatingTransformer.BindableMapper), typeof(FeatureContributionCalculatingTransformer.BindableMapper.Arguments),
2121
typeof(SignatureDataScorer), "Feature Contribution Transform", "fct", "FeatureContributionCalculationTransform", MetadataUtils.Const.ScoreColumnKind.FeatureContribution)]
2222

23-
[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(FeatureContributionCalculatingTransformer.BindableMapper), typeof(FeatureContributionCalculatingTransformer.Arguments),
23+
[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(FeatureContributionCalculatingTransformer.BindableMapper), typeof(FeatureContributionCalculatingTransformer.BindableMapper.Arguments),
2424
typeof(SignatureBindableMapper), "Feature Contribution Mapper", "fct", MetadataUtils.Const.ScoreColumnKind.FeatureContribution)]
2525

26-
[assembly: LoadableClass(typeof(ISchemaBindableMapper), typeof(FeatureContributionCalculatingTransformer.BindableMapper), null, typeof(SignatureLoadModel),
27-
"Feature Contribution Mapper", FeatureContributionCalculatingTransformer.MapperLoaderSignature)]
26+
[assembly: LoadableClass(typeof(FeatureContributionCalculatingTransformer.BindableMapper), typeof(FeatureContributionCalculatingTransformer.BindableMapper), null, typeof(SignatureLoadModel),
27+
"Feature Contribution Mapper", FeatureContributionCalculatingTransformer.BindableMapper.MapperLoaderSignature)]
2828

2929
[assembly: LoadableClass(FeatureContributionCalculatingTransformer.Summary, typeof(FeatureContributionCalculatingTransformer), null, typeof(SignatureLoadModel),
3030
FeatureContributionCalculatingTransformer.FriendlyName, FeatureContributionCalculatingTransformer.LoaderSignature)]
@@ -51,38 +51,12 @@ namespace Microsoft.ML.Runtime.Data
5151
/// </example>
5252
public sealed class FeatureContributionCalculatingTransformer : RowToRowTransformerBase
5353
{
54-
public sealed class Arguments : ScorerArgumentsBase
55-
{
56-
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of top contributions", SortOrder = 1)]
57-
public int Top = 10;
58-
59-
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bottom contributions", SortOrder = 2)]
60-
public int Bottom = 10;
61-
62-
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether or not output of Features contribution should be normalized", ShortName = "norm", SortOrder = 3)]
63-
public bool Normalize = true;
64-
65-
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether or not output of Features contribution in string key-value format", ShortName = "str", SortOrder = 4)]
66-
public bool Stringify = false;
67-
68-
// REVIEW: the scorer currently ignores the 'suffix' argument from the base class. It should respect it.
69-
}
70-
7154
// Apparently, loader signature is limited in length to 24 characters.
7255
internal const string Summary = "For each data point, calculates the contribution of individual features to the model prediction.";
7356
internal const string FriendlyName = "Feature Contribution Transform";
7457
internal const string LoaderSignature = "FeatureContribution";
7558

76-
internal const string MapperLoaderSignature = "WTFBindable";
77-
78-
private const int MaxTopBottom = 1000;
79-
80-
private readonly string _features;
81-
private readonly int _topContributionsCount;
82-
private readonly int _bottomContributionsCount;
83-
private readonly bool _normalize;
84-
private readonly bool _stringify;
85-
private readonly IFeatureContributionMapper _predictor;
59+
private readonly string _featureColumn;
8660
private readonly BindableMapper _mapper;
8761

8862
private static VersionInfo GetVersionInfo()
@@ -97,23 +71,22 @@ private static VersionInfo GetVersionInfo()
9771
}
9872

9973
// TODO documentation
100-
public FeatureContributionCalculatingTransformer(IHostEnvironment env, IPredictor predictor, string featuresColumn, Arguments args)
74+
public FeatureContributionCalculatingTransformer(IHostEnvironment env, IPredictor predictor, string featureColumn,
75+
int top = FeatureContributionCalculatingEstimator.Defaults.Top,
76+
int bottom = FeatureContributionCalculatingEstimator.Defaults.Bottom,
77+
bool normalize = FeatureContributionCalculatingEstimator.Defaults.Normalize,
78+
bool stringigy = FeatureContributionCalculatingEstimator.Defaults.Stringify)
10179
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)))
10280
{
103-
Host.CheckValue(args, nameof(args));
10481
Host.CheckValue(predictor, nameof(predictor));
82+
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
10583

10684
var pred = predictor as IFeatureContributionMapper;
10785
Host.CheckParam(pred != null, nameof(predictor), "Predictor doesn't support getting feature contributions");
10886

10987
// TODO check that the featues column is not empty.
110-
_mapper = new BindableMapper(Host, pred, args.Top, args.Bottom, args.Normalize, args.Stringify);
111-
_features = featuresColumn;
112-
_predictor = pred;
113-
_stringify = args.Stringify;
114-
_topContributionsCount = args.Top;
115-
_bottomContributionsCount = args.Bottom;
116-
_normalize = args.Normalize;
88+
_featureColumn = featureColumn;
89+
_mapper = new BindableMapper(Host, pred, top, bottom, normalize, stringigy);
11790
}
11891

11992
// Factory method for SignatureLoadModel
@@ -127,17 +100,8 @@ private FeatureContributionCalculatingTransformer(IHostEnvironment env, ModelLoa
127100
// string features
128101
// BindableMapper mapper
129102

130-
// TODO use ctx.LoadModel with BindableMapper instead of this.
131-
_features = ctx.LoadNonEmptyString();
132-
ctx.LoadModel<IFeatureContributionMapper, SignatureLoadModel>(env, out _predictor, ModelFileUtils.DirPredictor);
133-
_topContributionsCount = ctx.Reader.ReadInt32();
134-
Contracts.CheckDecode(0 < _topContributionsCount && _topContributionsCount <= MaxTopBottom);
135-
_bottomContributionsCount = ctx.Reader.ReadInt32();
136-
Contracts.CheckDecode(0 < _bottomContributionsCount && _bottomContributionsCount <= MaxTopBottom);
137-
_normalize = ctx.Reader.ReadBoolByte();
138-
_stringify = ctx.Reader.ReadBoolByte();
139-
140-
_mapper = new BindableMapper(env, _predictor, _topContributionsCount, _bottomContributionsCount, _normalize, _stringify);
103+
_featureColumn = ctx.LoadNonEmptyString();
104+
ctx.LoadModel<BindableMapper, SignatureLoadModel>(env, out _mapper, ModelFileUtils.DirPredictor);
141105
}
142106

143107
// Factory method for SignatureLoadRowMapper.
@@ -153,15 +117,8 @@ public override void Save(ModelSaveContext ctx)
153117
// string features
154118
// BindableMapper mapper
155119

156-
ctx.SaveNonEmptyString(_features);
157-
// TODO use ctx.SaveModel with BindableMapper instead of this.
158-
ctx.SaveModel(_predictor, ModelFileUtils.DirPredictor);
159-
Contracts.Assert(0 < _topContributionsCount && _topContributionsCount <= MaxTopBottom);
160-
ctx.Writer.Write(_topContributionsCount);
161-
Contracts.Assert(0 < _bottomContributionsCount && _bottomContributionsCount <= MaxTopBottom);
162-
ctx.Writer.Write(_bottomContributionsCount);
163-
ctx.Writer.WriteBoolByte(_normalize);
164-
ctx.Writer.WriteBoolByte(_stringify);
120+
ctx.SaveNonEmptyString(_featureColumn);
121+
ctx.SaveModel(_mapper, ModelFileUtils.DirPredictor);
165122
}
166123

167124
private protected override IRowMapper MakeRowMapper(Schema schema)
@@ -179,12 +136,11 @@ private class Mapper : MapperBase
179136
public Mapper(FeatureContributionCalculatingTransformer parent, Schema schema)
180137
: base(parent.Host, schema)
181138
{
182-
// TODO some checks? get soem of the columns, initialize some stuff
183139
_parent = parent;
184140
_bindableMapper = _parent._mapper;
185141

186142
var roles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();
187-
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, _parent._features));
143+
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Feature, _parent._featureColumn));
188144
_roleMappedSchema = new RoleMappedSchema(InputSchema, roles);
189145

190146
var genericMapper = _bindableMapper.GenericMapper.Bind(Host, _roleMappedSchema);
@@ -204,7 +160,7 @@ public Mapper(FeatureContributionCalculatingTransformer parent, Schema schema)
204160
private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
205161
{
206162
var active = new bool[InputSchema.ColumnCount];
207-
InputSchema.TryGetColumnIndex(_parent._features, out int featureCol);
163+
InputSchema.TryGetColumnIndex(_parent._featureColumn, out int featureCol);
208164
active[featureCol] = true;
209165
return col => active[col];
210166
}
@@ -219,8 +175,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore()
219175
var result = new List<Schema.DetachedColumn>();
220176

221177
// Add Score Column.
222-
foreach (var pair in _outputGenericSchema.GetColumns())
223-
result.Add(new Schema.DetachedColumn(pair.column));
178+
result.AddRange(_outputGenericSchema.GetColumns().Select(pair => new Schema.DetachedColumn(pair.column)));
224179

225180
// Add FeatureContributions column.
226181
var builder = new MetadataBuilder();
@@ -263,6 +218,23 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func<int, bool> act
263218
// TODO documentation
264219
internal sealed class BindableMapper : ISchemaBindableMapper, ICanSaveModel, IPredictor
265220
{
221+
public sealed class Arguments : ScorerArgumentsBase
222+
{
223+
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of top contributions", SortOrder = 1)]
224+
public int Top = FeatureContributionCalculatingEstimator.Defaults.Top;
225+
226+
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bottom contributions", SortOrder = 2)]
227+
public int Bottom = FeatureContributionCalculatingEstimator.Defaults.Bottom;
228+
229+
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether or not output of Features contribution should be normalized", ShortName = "norm", SortOrder = 3)]
230+
public bool Normalize = FeatureContributionCalculatingEstimator.Defaults.Normalize;
231+
232+
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether or not output of Features contribution in string key-value format", ShortName = "str", SortOrder = 4)]
233+
public bool Stringify = FeatureContributionCalculatingEstimator.Defaults.Stringify;
234+
235+
// REVIEW: the scorer currently ignores the 'suffix' argument from the base class. It should respect it.
236+
}
237+
266238
private readonly int _topContributionsCount;
267239
private readonly int _bottomContributionsCount;
268240
private readonly bool _normalize;
@@ -272,6 +244,9 @@ internal sealed class BindableMapper : ISchemaBindableMapper, ICanSaveModel, IPr
272244
public readonly ISchemaBindableMapper GenericMapper;
273245
public readonly bool Stringify;
274246

247+
internal const string MapperLoaderSignature = "WTFBindable";
248+
private const int MaxTopBottom = 1000;
249+
275250
private static VersionInfo GetVersionInfo()
276251
{
277252
return new VersionInfo(
@@ -604,47 +579,64 @@ public void GetMetadata<TValue>(string kind, int col, ref TValue value)
604579
// TODO DOcumentation
605580
public sealed class FeatureContributionCalculatingEstimator : TrivialEstimator<FeatureContributionCalculatingTransformer>
606581
{
607-
private readonly FeatureContributionCalculatingTransformer.Arguments _args;
608-
private readonly string _features;
582+
private readonly string _featureColumn;
609583
private readonly IPredictor _predictor;
584+
private readonly bool _stringify;
585+
586+
public static class Defaults
587+
{
588+
public const int Top = 10;
589+
public const int Bottom = 10;
590+
public const bool Normalize = true;
591+
public const bool Stringify = false;
592+
}
610593

611594
// TODO Documentation
612-
public FeatureContributionCalculatingEstimator(IHostEnvironment env, IPredictor predictor, string featuresColumn, FeatureContributionCalculatingTransformer.Arguments args)
613-
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)), new FeatureContributionCalculatingTransformer(env, predictor, featuresColumn, args))
595+
public FeatureContributionCalculatingEstimator(IHostEnvironment env, IPredictor predictor, string featureColumn,
596+
int top = Defaults.Top,
597+
int bottom = Defaults.Bottom,
598+
bool normalize = Defaults.Normalize,
599+
bool stringify = Defaults.Stringify)
600+
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)),
601+
new FeatureContributionCalculatingTransformer(env, predictor, featureColumn, top, bottom, normalize, stringify))
614602
{
615-
// TODO argcheck
616-
_args = args;
617-
_features = featuresColumn;
603+
_featureColumn = featureColumn;
618604
_predictor = predictor;
605+
_stringify = stringify;
619606
}
620607

621608
public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
622609
{
623610
Host.CheckValue(inputSchema, nameof(inputSchema));
624611
var result = inputSchema.ToDictionary(x => x.Name);
625612

626-
if (!inputSchema.TryFindColumn(_features, out var col))
627-
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _features);
628-
var metadata = new List<SchemaShape.Column>();
629-
if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta))
630-
metadata.Add(slotMeta);
631-
// TODO: check type of feature column.
632-
633-
// TODO: How do we deal with multiclassScoreColumn? should also contain slotnames
634613
// Add Score column.
614+
var scoreMetadata = new List<SchemaShape.Column>();
615+
// If multiclass, there could be a SlotNames metadata column, so it is added to the score column metadata in case.
616+
if (_predictor.PredictionKind == PredictionKind.MultiClassClassification)
617+
scoreMetadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false));
618+
// For some trainers the output could be normalized, but it cannot be known it given the information available here, so it is added in case.
619+
scoreMetadata.AddRange(MetadataUtils.GetTrainerOutputMetadata(isNormalized: true));
635620
result[DefaultColumnNames.Score] = new SchemaShape.Column(DefaultColumnNames.Score, SchemaShape.Column.VectorKind.Scalar, NumberType.R4,
636-
false, new SchemaShape(MetadataUtils.GetTrainerOutputMetadata()));
621+
false, new SchemaShape(scoreMetadata));
637622

638623
// Add FeatureContributions column.
639-
if (_args.Stringify)
624+
if (!inputSchema.TryFindColumn(_featureColumn, out var col))
625+
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _featureColumn);
626+
627+
// TODO: check type of feature column.
628+
if (_stringify)
640629
{
641630
result[DefaultColumnNames.FeatureContributions] = new SchemaShape.Column(DefaultColumnNames.FeatureContributions, col.Kind,
642-
TextType.Instance, false, new SchemaShape(metadata.ToArray()));
631+
TextType.Instance, false, null);
643632
}
644633
else
645634
{
635+
var featContributionMetadata = new List<SchemaShape.Column>();
636+
if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta))
637+
featContributionMetadata.Add(slotMeta);
646638
result[DefaultColumnNames.FeatureContributions] = new SchemaShape.Column(DefaultColumnNames.FeatureContributions, col.Kind,
647-
col.ItemType, false, new SchemaShape(metadata.ToArray()));
639+
col.ItemType, false, new SchemaShape(featContributionMetadata.ToArray()));
648640
}
649641

650642
return new SchemaShape(result.Values);

0 commit comments

Comments
 (0)