Skip to content

Commit 337cc55

Browse files
authored
Update of FeatureContributionCalculation to new API (#1847)
1 parent dea6c02 commit 337cc55

File tree

51 files changed

+3021
-221
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+3021
-221
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/FeatureContributionCalculationTransform.cs

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,32 +48,36 @@ public static void FeatureContributionCalculationTransform_Regression()
4848
var transformPipeline = mlContext.Transforms.Concatenate("Features", "CrimesPerCapita", "PercentResidental",
4949
"PercentNonRetail", "CharlesRiver", "NitricOxides", "RoomsPerDwelling", "PercentPre40s",
5050
"EmploymentDistance", "HighwayDistance", "TaxRate", "TeacherRatio");
51-
var learner = mlContext.Regression.Trainers.StochasticDualCoordinateAscent(
51+
var learner = mlContext.Regression.Trainers.OrdinaryLeastSquares(
5252
labelColumn: "MedianHomeValue", featureColumn: "Features");
5353

5454
var transformedData = transformPipeline.Fit(data).Transform(data);
5555

56+
// Now we train the model and score it on the transformed data.
5657
var model = learner.Fit(transformedData);
58+
var scoredData = model.Transform(transformedData);
5759

5860
// Create a Feature Contribution Calculator
59-
// Calculate the feature contributions for all features
61+
// Calculate the feature contributions for all features given trained model parameters
6062
// And don't normalize the contribution scores
61-
var args = new FeatureContributionCalculationTransform.Arguments()
62-
{
63-
Top = 11,
64-
Normalize = false
65-
};
66-
var featureContributionCalculator = FeatureContributionCalculationTransform.Create(mlContext, args, transformedData, model.Model, model.FeatureColumn);
63+
var featureContributionCalculator = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumn, top: 11, normalize: false);
64+
var outputData = featureContributionCalculator.Fit(scoredData).Transform(scoredData);
65+
66+
// FeatureContributionCalculatingEstimator can be used as an intermediary step in a pipeline.
67+
// The features retained by FeatureContributionCalculatingEstimator will be in the FeatureContributions column.
68+
var pipeline = mlContext.Model.Explainability.FeatureContributionCalculation(model.Model, model.FeatureColumn, top: 11)
69+
.Append(mlContext.Regression.Trainers.OrdinaryLeastSquares(featureColumn: "FeatureContributions"));
70+
var outData = featureContributionCalculator.Fit(scoredData).Transform(scoredData);
6771

6872
// Let's extract the weights from the linear model to use as a comparison
6973
var weights = new VBuffer<float>();
7074
model.Model.GetFeatureWeights(ref weights);
7175

7276
// Let's now walk through the first ten records and see which feature drove the values the most
7377
// Get prediction scores and contributions
74-
var scoringEnumerator = featureContributionCalculator.AsEnumerable<HousingRegressionScoreAndContribution>(mlContext, true).GetEnumerator();
78+
var scoringEnumerator = outputData.AsEnumerable<HousingRegressionScoreAndContribution>(mlContext, true).GetEnumerator();
7579
int index = 0;
76-
Console.WriteLine("Label\tScore\tBiggestFeature\tValue\tWeight\tContribution\tPercent");
80+
Console.WriteLine("Label\tScore\tBiggestFeature\tValue\tWeight\tContribution");
7781
while (scoringEnumerator.MoveNext() && index < 10)
7882
{
7983
var row = scoringEnumerator.Current;
@@ -84,26 +88,34 @@ public static void FeatureContributionCalculationTransform_Regression()
8488
// And the corresponding information about the feature
8589
var value = row.Features[featureOfInterest];
8690
var contribution = row.FeatureContributions[featureOfInterest];
87-
var percentContribution = 100 * contribution / row.Score;
88-
var name = data.Schema[(int) (featureOfInterest + 1)].Name;
91+
var name = data.Schema[featureOfInterest + 1].Name;
8992
var weight = weights.GetValues()[featureOfInterest];
9093

91-
Console.WriteLine("{0:0.00}\t{1:0.00}\t{2}\t{3:0.00}\t{4:0.00}\t{5:0.00}\t{6:0.00}",
94+
Console.WriteLine("{0:0.00}\t{1:0.00}\t{2}\t{3:0.00}\t{4:0.00}\t{5:0.00}",
9295
row.MedianHomeValue,
9396
row.Score,
9497
name,
9598
value,
9699
weight,
97-
contribution,
98-
percentContribution
100+
contribution
99101
);
100102

101103
index++;
102104
}
103-
104-
// For bulk scoring, the ApplyToData API can also be used
105-
var scoredData = featureContributionCalculator.ApplyToData(mlContext, transformedData);
106-
var preview = scoredData.Preview(100);
105+
Console.ReadLine();
106+
107+
// The output of the above code is:
108+
// Label Score BiggestFeature Value Weight Contribution
109+
// 24.00 27.74 RoomsPerDwelling 6.58 98.55 39.95
110+
// 21.60 23.85 RoomsPerDwelling 6.42 98.55 39.01
111+
// 34.70 29.29 RoomsPerDwelling 7.19 98.55 43.65
112+
// 33.40 27.17 RoomsPerDwelling 7.00 98.55 42.52
113+
// 36.20 27.68 RoomsPerDwelling 7.15 98.55 43.42
114+
// 28.70 23.13 RoomsPerDwelling 6.43 98.55 39.07
115+
// 22.90 22.71 RoomsPerDwelling 6.01 98.55 36.53
116+
// 27.10 21.72 RoomsPerDwelling 6.17 98.55 37.50
117+
// 16.50 18.04 RoomsPerDwelling 5.63 98.55 34.21
118+
// 18.90 20.14 RoomsPerDwelling 6.00 98.55 36.48
107119
}
108120

109121
private static int GetMostContributingFeature(float[] featureContributions)

docs/samples/Microsoft.ML.Samples/Program.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ internal static class Program
66
{
77
static void Main(string[] args)
88
{
9-
TensorFlowTransformExample.TensorFlowScoringSample();
9+
FeatureContributionCalculationTransform_RegressionExample.FeatureContributionCalculationTransform_Regression();
1010
}
1111
}
1212
}

src/Microsoft.ML.Core/Data/MetadataUtils.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -448,11 +448,11 @@ internal static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex
448448
if (!(schema[colIndex].Type is VectorType vecType && vecType.Size > 0))
449449
return isValid;
450450

451-
var type = schema[colIndex].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.CategoricalSlotRanges)?.Type;
451+
var type = schema[colIndex].Metadata.Schema.GetColumnOrNull(Kinds.CategoricalSlotRanges)?.Type;
452452
if (type?.RawType == typeof(VBuffer<int>))
453453
{
454454
VBuffer<int> catIndices = default(VBuffer<int>);
455-
schema[colIndex].Metadata.GetValue(MetadataUtils.Kinds.CategoricalSlotRanges, ref catIndices);
455+
schema[colIndex].Metadata.GetValue(Kinds.CategoricalSlotRanges, ref catIndices);
456456
VBufferUtils.Densify(ref catIndices);
457457
int columnSlotsCount = vecType.Size;
458458
if (catIndices.Length > 0 && catIndices.Length % 2 == 0 && catIndices.Length <= columnSlotsCount * 2)
@@ -498,14 +498,15 @@ internal static bool TryGetCategoricalFeatureIndices(Schema schema, int colIndex
498498
}
499499

500500
/// <summary>
501-
/// Produces sequence of columns that are generated by multiclass trainer estimators.
501+
/// Produces metadata for the score column generated by trainer estimators for multiclass classification.
502+
/// Slot names metadata is added only when the input label column is supplied and is a key type with key values.
502503
/// </summary>
503504
/// <param name="labelColumn">Label column.</param>
504505
[BestFriend]
505-
internal static IEnumerable<SchemaShape.Column> MetadataForMulticlassScoreColumn(SchemaShape.Column labelColumn)
506+
internal static IEnumerable<SchemaShape.Column> MetadataForMulticlassScoreColumn(SchemaShape.Column? labelColumn = null)
506507
{
507508
var cols = new List<SchemaShape.Column>();
508-
if (labelColumn.IsKey && HasKeyValues(labelColumn))
509+
if (labelColumn != null && labelColumn.Value.IsKey && HasKeyValues(labelColumn.Value))
509510
cols.Add(new SchemaShape.Column(Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false));
510511
cols.AddRange(GetTrainerOutputMetadata());
511512
return cols;

src/Microsoft.ML.Data/Dirty/PredictorBase.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
using Float = System.Single;
66

77
using System;
8-
using Microsoft.ML.Runtime;
98
using Microsoft.ML.Runtime.Model;
109

1110
namespace Microsoft.ML.Runtime.Internal.Internallearn

src/Microsoft.ML.Data/Dirty/PredictorInterfaces.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,23 @@ internal interface IFeatureContributionMapper : IPredictor
203203
ValueMapper<TSrc, VBuffer<float>> GetFeatureContributionMapper<TSrc, TDst>(int top, int bottom, bool normalize);
204204
}
205205

206+
/// <summary>
207+
/// Allows support for feature contribution calculation.
208+
/// </summary>
209+
public interface ICalculateFeatureContribution : IPredictor
210+
{
211+
FeatureContributionCalculator FeatureContributionClaculator { get; }
212+
}
213+
214+
/// <summary>
215+
/// Support for feature contribution calculation.
216+
/// </summary>
217+
public sealed class FeatureContributionCalculator
218+
{
219+
internal IFeatureContributionMapper ContributionMapper { get; }
220+
internal FeatureContributionCalculator(IFeatureContributionMapper contributionMapper) => ContributionMapper = contributionMapper;
221+
}
222+
206223
/// <summary>
207224
/// Interface for predictors that can return a string array containing the label names from the label column they were trained on.
208225
/// If the training label is a key with text key value metadata, it should return this metadata. The order of the labels should be consistent

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,25 @@ public sealed class ModelOperationsCatalog
1616
{
1717
internal IHostEnvironment Environment { get; }
1818

19+
public ExplainabilityTransforms Explainability { get; }
20+
1921
internal ModelOperationsCatalog(IHostEnvironment env)
2022
{
2123
Contracts.AssertValue(env);
2224
Environment = env;
25+
26+
Explainability = new ExplainabilityTransforms(this);
27+
}
28+
29+
public abstract class SubCatalogBase
30+
{
31+
internal IHostEnvironment Environment { get; }
32+
33+
protected SubCatalogBase(ModelOperationsCatalog owner)
34+
{
35+
Environment = owner.Environment;
36+
}
37+
2338
}
2439

2540
/// <summary>
@@ -36,6 +51,16 @@ internal ModelOperationsCatalog(IHostEnvironment env)
3651
/// <returns>The loaded model.</returns>
3752
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(Environment, stream);
3853

54+
/// <summary>
55+
/// The catalog of model explainability operations.
56+
/// </summary>
57+
public sealed class ExplainabilityTransforms : SubCatalogBase
58+
{
59+
internal ExplainabilityTransforms(ModelOperationsCatalog owner) : base(owner)
60+
{
61+
}
62+
}
63+
3964
/// <summary>
4065
/// Create a prediction engine for one-time prediction.
4166
/// </summary>

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ protected static ICalibrator GetCalibrator(IHostEnvironment env, ModelLoadContex
206206
}
207207
}
208208

209-
public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBase, IValueMapperDist, IFeatureContributionMapper,
209+
public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBase, IValueMapperDist, IFeatureContributionMapper, ICalculateFeatureContribution,
210210
IDistCanSavePfa, IDistCanSaveOnnx
211211
{
212212
private readonly IValueMapper _mapper;
@@ -216,6 +216,9 @@ public abstract class ValueMapperCalibratedPredictorBase : CalibratedPredictorBa
216216
ColumnType IValueMapper.OutputType => _mapper.OutputType;
217217
ColumnType IValueMapperDist.DistType => NumberType.Float;
218218
bool ICanSavePfa.CanSavePfa => (_mapper as ICanSavePfa)?.CanSavePfa == true;
219+
220+
public FeatureContributionCalculator FeatureContributionClaculator => new FeatureContributionCalculator(this);
221+
219222
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => (_mapper as ICanSaveOnnx)?.CanSaveOnnx(ctx) == true;
220223

221224
protected ValueMapperCalibratedPredictorBase(IHostEnvironment env, string name, IPredictorProducing<float> predictor, ICalibrator calibrator)

0 commit comments

Comments
 (0)