Skip to content

Commit 297a677

Browse files
committed
use the contravariance of ISingleFeaturePredictionTransformer instead of loading PredictionTransformer<object> from file
1 parent 1d5a4f0 commit 297a677

File tree

6 files changed

+52
-43
lines changed

6 files changed

+52
-43
lines changed

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@
6161
"Calibrated Predictor Executor",
6262
ValueMapperCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>.LoaderSignature, "BulkCaliPredExec")]
6363

64-
[assembly: LoadableClass(typeof(FeatureWeightsCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel),
64+
[assembly: LoadableClass(typeof(CalibratedModelParametersBase), typeof(FeatureWeightsCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>), null, typeof(SignatureLoadModel),
6565
"Feature Weights Calibrated Predictor Executor",
6666
FeatureWeightsCalibratedModelParameters<IPredictorWithFeatureWeights<float>, ICalibrator>.LoaderSignature)]
6767

68-
[assembly: LoadableClass(typeof(ParameterMixingCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel),
68+
[assembly: LoadableClass(typeof(CalibratedModelParametersBase), typeof(ParameterMixingCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>), null, typeof(SignatureLoadModel),
6969
"Parameter Mixing Calibrated Predictor Executor",
7070
ParameterMixingCalibratedModelParameters<IPredictorWithFeatureWeights<float>, ICalibrator>.LoaderSignature)]
7171

72-
[assembly: LoadableClass(typeof(SchemaBindableCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel),
72+
[assembly: LoadableClass(typeof(CalibratedModelParametersBase), typeof(SchemaBindableCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>), null, typeof(SignatureLoadModel),
7373
"Schema Bindable Calibrated Predictor", SchemaBindableCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>.LoaderSignature)]
7474

7575
[assembly: LoadableClass(typeof(void), typeof(Calibrate), null, typeof(SignatureEntryPointModule), "Calibrate")]
@@ -490,7 +490,7 @@ private FeatureWeightsCalibratedModelParameters(IHostEnvironment env, ModelLoadC
490490
_featureWeights = (IPredictorWithFeatureWeights<float>)SubModel;
491491
}
492492

493-
private static FeatureWeightsCalibratedModelParameters<TSubModel, TCalibrator> Create(IHostEnvironment env, ModelLoadContext ctx)
493+
private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
494494
{
495495
Contracts.CheckValue(env, nameof(env));
496496
env.CheckValue(ctx, nameof(ctx));
@@ -557,7 +557,7 @@ private ParameterMixingCalibratedModelParameters(IHostEnvironment env, ModelLoad
557557
_featureWeights = SubModel as IPredictorWithFeatureWeights<float>;
558558
}
559559

560-
private static ParameterMixingCalibratedModelParameters<TSubModel, TCalibrator> Create(IHostEnvironment env, ModelLoadContext ctx)
560+
private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
561561
{
562562
Contracts.CheckValue(env, nameof(env));
563563
env.CheckValue(ctx, nameof(ctx));
@@ -729,7 +729,7 @@ private SchemaBindableCalibratedModelParameters(IHostEnvironment env, ModelLoadC
729729
_featureContribution = SubModel as IFeatureContributionMapper;
730730
}
731731

732-
private static SchemaBindableCalibratedModelParameters<TSubModel, TCalibrator> Create(IHostEnvironment env, ModelLoadContext ctx)
732+
private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
733733
{
734734
Contracts.CheckValue(ctx, nameof(ctx));
735735
ctx.CheckAtModel(GetVersionInfo());

src/Microsoft.ML.Data/Prediction/IPredictionTransformer.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace Microsoft.ML
1414
/// </summary>
1515
/// <typeparam name="TModel">The <see cref="IPredictor"/> or <see cref="ICalibrator"/> used for the data transformation.</typeparam>
1616
public interface IPredictionTransformer<out TModel> : ITransformer
17+
where TModel : class
1718
{
1819
TModel Model { get; }
1920
}
@@ -25,6 +26,7 @@ public interface IPredictionTransformer<out TModel> : ITransformer
2526
/// </summary>
2627
/// <typeparam name="TModel">The <see cref="IPredictor"/> or <see cref="ICalibrator"/> used for the data transformation.</typeparam>
2728
public interface ISingleFeaturePredictionTransformer<out TModel> : IPredictionTransformer<TModel>
29+
where TModel : class
2830
{
2931
/// <summary>The name of the feature column.</summary>
3032
string FeatureColumn { get; }

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,22 @@
88
using Microsoft.ML.Data;
99
using Microsoft.ML.Data.IO;
1010

11-
[assembly: LoadableClass(typeof(BinaryPredictionTransformer<object>), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel),
11+
[assembly: LoadableClass(typeof(BinaryPredictionTransformer<IPredictorProducing<float>>), typeof(BinaryPredictionTransformer), null, typeof(SignatureLoadModel),
1212
"", BinaryPredictionTransformer.LoaderSignature)]
1313

14-
[assembly: LoadableClass(typeof(MulticlassPredictionTransformer<object>), typeof(MulticlassPredictionTransformer), null, typeof(SignatureLoadModel),
14+
[assembly: LoadableClass(typeof(MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>>), typeof(MulticlassPredictionTransformer), null, typeof(SignatureLoadModel),
1515
"", MulticlassPredictionTransformer.LoaderSignature)]
1616

17-
[assembly: LoadableClass(typeof(RegressionPredictionTransformer<object>), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel),
17+
[assembly: LoadableClass(typeof(RegressionPredictionTransformer<IPredictorProducing<float>>), typeof(RegressionPredictionTransformer), null, typeof(SignatureLoadModel),
1818
"", RegressionPredictionTransformer.LoaderSignature)]
1919

20-
[assembly: LoadableClass(typeof(RankingPredictionTransformer<object>), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel),
20+
[assembly: LoadableClass(typeof(RankingPredictionTransformer<IPredictorProducing<float>>), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel),
2121
"", RankingPredictionTransformer.LoaderSignature)]
2222

23-
[assembly: LoadableClass(typeof(AnomalyPredictionTransformer<object>), typeof(AnomalyPredictionTransformer), null, typeof(SignatureLoadModel),
23+
[assembly: LoadableClass(typeof(AnomalyPredictionTransformer<IPredictorProducing<float>>), typeof(AnomalyPredictionTransformer), null, typeof(SignatureLoadModel),
2424
"", AnomalyPredictionTransformer.LoaderSignature)]
2525

26-
[assembly: LoadableClass(typeof(ClusteringPredictionTransformer<object>), typeof(ClusteringPredictionTransformer), null, typeof(SignatureLoadModel),
26+
[assembly: LoadableClass(typeof(ClusteringPredictionTransformer<IPredictorProducing<VBuffer<float>>>), typeof(ClusteringPredictionTransformer), null, typeof(SignatureLoadModel),
2727
"", ClusteringPredictionTransformer.LoaderSignature)]
2828

2929
namespace Microsoft.ML.Data
@@ -51,7 +51,8 @@ public abstract class PredictionTransformerBase<TModel> : IPredictionTransformer
5151
private protected readonly IHost Host;
5252
[BestFriend]
5353
private protected ISchemaBindableMapper BindableMapper;
54-
protected DataViewSchema TrainSchema;
54+
[BestFriend]
55+
private protected DataViewSchema TrainSchema;
5556

5657
/// <summary>
5758
/// Whether a call to <see cref="ITransformer.GetRowToRowMapper(DataViewSchema)"/> should succeed, on an
@@ -141,7 +142,8 @@ IRowToRowMapper ITransformer.GetRowToRowMapper(DataViewSchema inputSchema)
141142

142143
private protected abstract void SaveModel(ModelSaveContext ctx);
143144

144-
protected void SaveModelCore(ModelSaveContext ctx)
145+
[BestFriend]
146+
private protected void SaveModelCore(ModelSaveContext ctx)
145147
{
146148
// *** Binary format ***
147149
// <base info>
@@ -233,14 +235,14 @@ public sealed override DataViewSchema GetOutputSchema(DataViewSchema inputSchema
233235
return Transform(new EmptyDataView(Host, inputSchema)).Schema;
234236
}
235237

236-
private protected override void SaveModel(ModelSaveContext ctx)
238+
private protected sealed override void SaveModel(ModelSaveContext ctx)
237239
{
238240
Host.CheckValue(ctx, nameof(ctx));
239241
ctx.CheckAtModel();
240242
SaveCore(ctx);
241243
}
242244

243-
protected virtual void SaveCore(ModelSaveContext ctx)
245+
private protected virtual void SaveCore(ModelSaveContext ctx)
244246
{
245247
SaveModelCore(ctx);
246248
ctx.SaveStringOrNull(FeatureColumn);
@@ -295,7 +297,7 @@ private void SetScorer()
295297
Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
296298
}
297299

298-
protected override void SaveCore(ModelSaveContext ctx)
300+
private protected override void SaveCore(ModelSaveContext ctx)
299301
{
300302
Contracts.AssertValue(ctx);
301303
ctx.SetVersionInfo(GetVersionInfo());
@@ -364,7 +366,7 @@ private void SetScorer()
364366
Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
365367
}
366368

367-
protected override void SaveCore(ModelSaveContext ctx)
369+
private protected override void SaveCore(ModelSaveContext ctx)
368370
{
369371
Contracts.AssertValue(ctx);
370372
ctx.SetVersionInfo(GetVersionInfo());
@@ -428,7 +430,7 @@ private void SetScorer()
428430
Scorer = new MultiClassClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
429431
}
430432

431-
protected override void SaveCore(ModelSaveContext ctx)
433+
private protected override void SaveCore(ModelSaveContext ctx)
432434
{
433435
Contracts.AssertValue(ctx);
434436
ctx.SetVersionInfo(GetVersionInfo());
@@ -473,7 +475,7 @@ internal RegressionPredictionTransformer(IHostEnvironment env, ModelLoadContext
473475
Scorer = GetGenericScorer();
474476
}
475477

476-
protected override void SaveCore(ModelSaveContext ctx)
478+
private protected override void SaveCore(ModelSaveContext ctx)
477479
{
478480
Contracts.AssertValue(ctx);
479481
ctx.SetVersionInfo(GetVersionInfo());
@@ -515,7 +517,7 @@ internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx
515517
Scorer = GetGenericScorer();
516518
}
517519

518-
protected override void SaveCore(ModelSaveContext ctx)
520+
private protected override void SaveCore(ModelSaveContext ctx)
519521
{
520522
Contracts.AssertValue(ctx);
521523
ctx.SetVersionInfo(GetVersionInfo());
@@ -567,7 +569,7 @@ internal ClusteringPredictionTransformer(IHostEnvironment env, ModelLoadContext
567569
Scorer = new ClusteringScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
568570
}
569571

570-
protected override void SaveCore(ModelSaveContext ctx)
572+
private protected override void SaveCore(ModelSaveContext ctx)
571573
{
572574
Contracts.AssertValue(ctx);
573575
ctx.SetVersionInfo(GetVersionInfo());
@@ -594,47 +596,47 @@ internal static class BinaryPredictionTransformer
594596
{
595597
public const string LoaderSignature = "BinaryPredXfer";
596598

597-
public static BinaryPredictionTransformer<object> Create(IHostEnvironment env, ModelLoadContext ctx)
598-
=> new BinaryPredictionTransformer<object>(env, ctx);
599+
public static BinaryPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
600+
=> new BinaryPredictionTransformer<IPredictorProducing<float>>(env, ctx);
599601
}
600602

601603
internal static class MulticlassPredictionTransformer
602604
{
603605
public const string LoaderSignature = "MulticlassPredXfer";
604606

605-
public static MulticlassPredictionTransformer<object> Create(IHostEnvironment env, ModelLoadContext ctx)
606-
=> new MulticlassPredictionTransformer<object>(env, ctx);
607+
public static MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>> Create(IHostEnvironment env, ModelLoadContext ctx)
608+
=> new MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>>(env, ctx);
607609
}
608610

609611
internal static class RegressionPredictionTransformer
610612
{
611613
public const string LoaderSignature = "RegressionPredXfer";
612614

613-
public static RegressionPredictionTransformer<object> Create(IHostEnvironment env, ModelLoadContext ctx)
614-
=> new RegressionPredictionTransformer<object>(env, ctx);
615+
public static RegressionPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
616+
=> new RegressionPredictionTransformer<IPredictorProducing<float>>(env, ctx);
615617
}
616618

617619
internal static class RankingPredictionTransformer
618620
{
619621
public const string LoaderSignature = "RankingPredXfer";
620622

621-
public static RankingPredictionTransformer<object> Create(IHostEnvironment env, ModelLoadContext ctx)
622-
=> new RankingPredictionTransformer<object>(env, ctx);
623+
public static RankingPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
624+
=> new RankingPredictionTransformer<IPredictorProducing<float>>(env, ctx);
623625
}
624626

625627
internal static class AnomalyPredictionTransformer
626628
{
627629
public const string LoaderSignature = "AnomalyPredXfer";
628630

629-
public static AnomalyPredictionTransformer<object> Create(IHostEnvironment env, ModelLoadContext ctx)
630-
=> new AnomalyPredictionTransformer<object>(env, ctx);
631+
public static AnomalyPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
632+
=> new AnomalyPredictionTransformer<IPredictorProducing<float>>(env, ctx);
631633
}
632634

633635
internal static class ClusteringPredictionTransformer
634636
{
635637
public const string LoaderSignature = "ClusteringPredXfer";
636638

637-
public static ClusteringPredictionTransformer<object> Create(IHostEnvironment env, ModelLoadContext ctx)
638-
=> new ClusteringPredictionTransformer<object>(env, ctx);
639+
public static ClusteringPredictionTransformer<IPredictorProducing<VBuffer<float>>> Create(IHostEnvironment env, ModelLoadContext ctx)
640+
=> new ClusteringPredictionTransformer<IPredictorProducing<VBuffer<float>>>(env, ctx);
639641
}
640642
}

src/Microsoft.ML.Transforms/PermutationFeatureImportance.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
namespace Microsoft.ML.Transforms
1717
{
18-
internal static class PermutationFeatureImportance<TModel, TMetric, TResult> where TResult : MetricsStatisticsBase<TMetric>, new()
18+
internal static class PermutationFeatureImportance<TModel, TMetric, TResult>
19+
where TResult : MetricsStatisticsBase<TMetric>, new()
20+
where TModel : class
1921
{
2022
public static ImmutableArray<TResult>
2123
GetImportanceMetricsMatrix(

src/Microsoft.ML.Transforms/PermutationFeatureImportanceExtensions.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public static ImmutableArray<RegressionMetricsStatistics>
6363
string features = DefaultColumnNames.Features,
6464
bool useFeatureWeightFilter = false,
6565
int? topExamples = null,
66-
int permutationCount = 1)
66+
int permutationCount = 1) where TModel : class
6767
{
6868
return PermutationFeatureImportance<TModel, RegressionMetrics, RegressionMetricsStatistics>.GetImportanceMetricsMatrix(
6969
catalog.GetEnvironment(),
@@ -140,7 +140,7 @@ public static ImmutableArray<BinaryClassificationMetricsStatistics>
140140
string features = DefaultColumnNames.Features,
141141
bool useFeatureWeightFilter = false,
142142
int? topExamples = null,
143-
int permutationCount = 1)
143+
int permutationCount = 1) where TModel : class
144144
{
145145
return PermutationFeatureImportance<TModel, BinaryClassificationMetrics, BinaryClassificationMetricsStatistics>.GetImportanceMetricsMatrix(
146146
catalog.GetEnvironment(),
@@ -214,7 +214,7 @@ public static ImmutableArray<MultiClassClassifierMetricsStatistics>
214214
string features = DefaultColumnNames.Features,
215215
bool useFeatureWeightFilter = false,
216216
int? topExamples = null,
217-
int permutationCount = 1)
217+
int permutationCount = 1) where TModel : class
218218
{
219219
return PermutationFeatureImportance<TModel, MultiClassClassifierMetrics, MultiClassClassifierMetricsStatistics>.GetImportanceMetricsMatrix(
220220
catalog.GetEnvironment(),
@@ -295,7 +295,7 @@ public static ImmutableArray<RankingMetricsStatistics>
295295
string features = DefaultColumnNames.Features,
296296
bool useFeatureWeightFilter = false,
297297
int? topExamples = null,
298-
int permutationCount = 1)
298+
int permutationCount = 1) where TModel : class
299299
{
300300
return PermutationFeatureImportance<TModel, RankingMetrics, RankingMetricsStatistics>.GetImportanceMetricsMatrix(
301301
catalog.GetEnvironment(),

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public void LoadModelAndExtractPredictor()
4848
loadedModel = ml.Model.Load(fs);
4949

5050
var gam = (((loadedModel as TransformerChain<ITransformer>).LastTransformer
51-
as BinaryPredictionTransformer<object>).Model
51+
as ISingleFeaturePredictionTransformer<object>).Model
5252
as CalibratedModelParametersBase).SubModel
5353
as BinaryClassificationGamModelParameters;
5454
Assert.NotNull(gam);
@@ -92,9 +92,12 @@ public void SaveAndLoadModelWithLoader()
9292
data.Schema["Features"].GetSlotNames(ref slotNames);
9393
var ageIndex = FindIndex(slotNames.GetValues(), "age");
9494
var transformer = (loadedModel as CompositeDataLoader<IMultiStreamSource, ITransformer>).Transformer.LastTransformer;
95-
var gamModel = ((transformer as BinaryPredictionTransformer<object>).Model
96-
as CalibratedModelParametersBase).SubModel
97-
as BinaryClassificationGamModelParameters;
95+
var singleFeaturePredictionTransformer = transformer as ISingleFeaturePredictionTransformer<object>;
96+
Assert.NotNull(singleFeaturePredictionTransformer);
97+
var calibratedModelParameters = singleFeaturePredictionTransformer.Model as CalibratedModelParametersBase;
98+
Assert.NotNull(calibratedModelParameters);
99+
var gamModel = calibratedModelParameters.SubModel as BinaryClassificationGamModelParameters;
100+
Assert.NotNull(gamModel);
98101
var ageBinUpperBounds = gamModel.GetBinUpperBounds(ageIndex);
99102
var ageBinEffects = gamModel.GetBinEffects(ageIndex);
100103
}

0 commit comments

Comments
 (0)