Skip to content

Commit 1d5a4f0

Browse files
committed
Address some code review comments, add a non-generic base class for calibrated predictor
1 parent 495740e commit 1d5a4f0

File tree

4 files changed

+23
-32
lines changed

4 files changed

+23
-32
lines changed

src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -98,29 +98,6 @@ void ICanSaveModel.Save(ModelSaveContext ctx)
9898
ctx.SaveModel(Transformer, TransformerDirectory);
9999
}
100100

101-
/// <summary>
102-
/// Save the contents to a stream, as a "model file".
103-
/// </summary>
104-
public void SaveTo(IHostEnvironment env, Stream outputStream)
105-
{
106-
Contracts.CheckValue(env, nameof(env));
107-
env.CheckValue(outputStream, nameof(outputStream));
108-
109-
env.Check(outputStream.CanWrite && outputStream.CanSeek, "Need a writable and seekable stream to save");
110-
using (var ch = env.Start("Saving pipeline"))
111-
{
112-
using (var rep = RepositoryWriter.CreateNew(outputStream, ch))
113-
{
114-
ch.Trace("Saving data loader");
115-
ModelSaveContext.SaveModel(rep, Loader, LoaderDirectory);
116-
117-
ch.Trace("Saving transformer chain");
118-
ModelSaveContext.SaveModel(rep, Transformer, TransformerDirectory);
119-
rep.Commit();
120-
}
121-
}
122-
}
123-
124101
internal const string Summary = "A loader that encapsulates a loader and a transformer chain.";
125102

126103
internal const string LoaderSignature = "CompositeLoader";

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, Strea
5151
/// <returns>The loaded model.</returns>
5252
public ITransformer Load(Stream stream) => TransformerChain.LoadFrom(_env, stream);
5353

54-
public IDataLoader<IMultiStreamSource> LoadAsCompositeDataLoader(Stream stream)
54+
public CompositeDataLoader<IMultiStreamSource, ITransformer> LoadAsCompositeDataLoader(Stream stream)
5555
{
5656
using (var rep = RepositoryReader.Open(stream))
5757
{
58-
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out var model, rep, "Model");
58+
ModelLoadContext.LoadModel<CompositeDataLoader<IMultiStreamSource, ITransformer>, SignatureLoadModel>(_env, out var model, rep, "Model");
5959
return model;
6060
}
6161
}

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
"Naive Calibration Executor",
5858
NaiveCalibrator.LoaderSignature)]
5959

60-
[assembly: LoadableClass(typeof(ValueMapperCalibratedModelParameters<object, ICalibrator>), null, typeof(SignatureLoadModel),
60+
[assembly: LoadableClass(typeof(CalibratedModelParametersBase), typeof(ValueMapperCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>), null, typeof(SignatureLoadModel),
6161
"Calibrated Predictor Executor",
6262
ValueMapperCalibratedModelParameters<IPredictorProducing<float>, ICalibrator>.LoaderSignature, "BulkCaliPredExec")]
6363

@@ -150,6 +150,18 @@ internal interface IWeaklyTypedCalibratedModelParameters
150150
ICalibrator WeaklyTypedCalibrator { get; }
151151
}
152152

153+
public abstract class CalibratedModelParametersBase
154+
{
155+
public object SubModel { get; }
156+
public ICalibrator Calibrator { get; }
157+
158+
private protected CalibratedModelParametersBase(object subModel, ICalibrator calibrator)
159+
{
160+
SubModel = subModel;
161+
Calibrator = calibrator;
162+
}
163+
}
164+
153165
/// <summary>
154166
/// Class for allowing a post-processing step, defined by <see cref="Calibrator"/>, to <see cref="SubModel"/>'s
155167
/// output.
@@ -161,7 +173,7 @@ internal interface IWeaklyTypedCalibratedModelParameters
161173
/// output value to the probability of belonging to the positive (or negative) class. Detailed math materials
162174
/// can be found at <a href="https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf">this paper</a>.
163175
/// </remarks>
164-
public abstract class CalibratedModelParametersBase<TSubModel, TCalibrator> :
176+
public abstract class CalibratedModelParametersBase<TSubModel, TCalibrator> : CalibratedModelParametersBase,
165177
IDistPredictorProducing<float, float>,
166178
ICanSaveInIniFormat,
167179
ICanSaveInTextFormat,
@@ -178,11 +190,12 @@ public abstract class CalibratedModelParametersBase<TSubModel, TCalibrator> :
178190
/// <summary>
179191
/// <see cref="SubModel"/>'s output would calibrated by <see cref="Calibrator"/>.
180192
/// </summary>
181-
public TSubModel SubModel { get; }
193+
public new TSubModel SubModel { get; }
194+
182195
/// <summary>
183196
/// <see cref="Calibrator"/> is used to post-process score produced by <see cref="SubModel"/>.
184197
/// </summary>
185-
public TCalibrator Calibrator { get; }
198+
public new TCalibrator Calibrator { get; }
186199

187200
// Type-unsafed accessors of strongly-typed members.
188201
IPredictorProducing<float> IWeaklyTypedCalibratedModelParameters.WeaklyTypedSubModel => (IPredictorProducing<float>)SubModel;
@@ -191,6 +204,7 @@ public abstract class CalibratedModelParametersBase<TSubModel, TCalibrator> :
191204
PredictionKind IPredictor.PredictionKind => ((IPredictorProducing<float>)SubModel).PredictionKind;
192205

193206
private protected CalibratedModelParametersBase(IHostEnvironment env, string name, TSubModel predictor, TCalibrator calibrator)
207+
: base(predictor, calibrator)
194208
{
195209
Contracts.CheckValue(env, nameof(env));
196210
env.CheckNonWhiteSpace(name, nameof(name));
@@ -417,7 +431,7 @@ private ValueMapperCalibratedModelParameters(IHostEnvironment env, ModelLoadCont
417431
{
418432
}
419433

420-
private static ValueMapperCalibratedModelParameters<TSubModel, TCalibrator> Create(IHostEnvironment env, ModelLoadContext ctx)
434+
private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
421435
{
422436
Contracts.CheckValue(ctx, nameof(ctx));
423437
// Can load either the old "bulk" model or standard "cali". The two formats are identical.

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DeserializationTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void LoadModelAndExtractPredictor()
4949

5050
var gam = (((loadedModel as TransformerChain<ITransformer>).LastTransformer
5151
as BinaryPredictionTransformer<object>).Model
52-
as CalibratedModelParametersBase<object, ICalibrator>).SubModel
52+
as CalibratedModelParametersBase).SubModel
5353
as BinaryClassificationGamModelParameters;
5454
Assert.NotNull(gam);
5555
}
@@ -93,7 +93,7 @@ public void SaveAndLoadModelWithLoader()
9393
var ageIndex = FindIndex(slotNames.GetValues(), "age");
9494
var transformer = (loadedModel as CompositeDataLoader<IMultiStreamSource, ITransformer>).Transformer.LastTransformer;
9595
var gamModel = ((transformer as BinaryPredictionTransformer<object>).Model
96-
as CalibratedModelParametersBase<object, ICalibrator>).SubModel
96+
as CalibratedModelParametersBase).SubModel
9797
as BinaryClassificationGamModelParameters;
9898
var ageBinUpperBounds = gamModel.GetBinUpperBounds(ageIndex);
9999
var ageBinEffects = gamModel.GetBinEffects(ageIndex);

0 commit comments

Comments
 (0)