Skip to content

Commit fcd98d1

Browse files
committed
Add unit tests, and address some code review comments
1 parent ec8ad0d commit fcd98d1

File tree

13 files changed

+297
-115
lines changed

13 files changed

+297
-115
lines changed

src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs

Lines changed: 34 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -251,54 +251,41 @@ private static TransformerChain<ITransformer> Create(IHostEnvironment env, Model
251251
public static void SaveTo(this ITransformer transformer, IHostEnvironment env, Stream outputStream)
252252
=> new TransformerChain<ITransformer>(transformer).SaveTo(env, outputStream);
253253

254-
public static ITransformer LoadFrom(IHostEnvironment env, Stream stream)
254+
public static ITransformer LoadFromLegacy(IHostEnvironment env, Stream stream)
255255
{
256-
using (var rep = RepositoryReader.Open(stream, env))
257-
{
258-
try
259-
{
260-
ModelLoadContext.LoadModel<ITransformer, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature);
261-
return transformerChain;
262-
}
263-
catch (FormatException ex)
264-
{
265-
if (!ex.IsMarked())
266-
throw;
267-
var chain = ModelFileUtils.LoadPipeline(env, stream, new MultiFileSource(null), extractInnerPipe: false);
268-
TransformerChain<ITransformer> transformChain = (chain as LegacyCompositeDataLoader).GetTransformer();
269-
var predictor = ModelFileUtils.LoadPredictorOrNull(env, stream);
270-
if (predictor == null)
271-
return transformChain;
272-
var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, stream);
273-
env.CheckDecode(roles != null, "Predictor model must contain role mappings");
274-
var roleMappings = roles.ToArray();
275-
276-
ITransformer pred = null;
277-
if (predictor.PredictionKind == PredictionKind.BinaryClassification)
278-
pred = new BinaryPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
279-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
280-
else if (predictor.PredictionKind == PredictionKind.MulticlassClassification)
281-
pred = new MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>>(env,
282-
predictor as IPredictorProducing<VBuffer<float>>, chain.Schema,
283-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value,
284-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value);
285-
else if (predictor.PredictionKind == PredictionKind.Clustering)
286-
pred = new ClusteringPredictionTransformer<IPredictorProducing<VBuffer<float>>>(env, predictor as IPredictorProducing<VBuffer<float>>, chain.Schema,
287-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
288-
else if (predictor.PredictionKind == PredictionKind.Regression)
289-
pred = new RegressionPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
290-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
291-
else if (predictor.PredictionKind == PredictionKind.AnomalyDetection)
292-
pred = new AnomalyPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
293-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
294-
else if (predictor.PredictionKind == PredictionKind.Ranking)
295-
pred = new RankingPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
296-
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
297-
else
298-
throw env.Except("Don't know how to map prediction kind {0}", predictor.PredictionKind);
299-
return transformChain.Append(pred);
300-
}
301-
}
256+
var chain = ModelFileUtils.LoadPipeline(env, stream, new MultiFileSource(null), extractInnerPipe: false);
257+
TransformerChain<ITransformer> transformChain = (chain as LegacyCompositeDataLoader).GetTransformer();
258+
var predictor = ModelFileUtils.LoadPredictorOrNull(env, stream);
259+
if (predictor == null)
260+
return transformChain;
261+
var roles = ModelFileUtils.LoadRoleMappingsOrNull(env, stream);
262+
env.CheckDecode(roles != null, "Predictor model must contain role mappings");
263+
var roleMappings = roles.ToArray();
264+
265+
ITransformer pred = null;
266+
if (predictor.PredictionKind == PredictionKind.BinaryClassification)
267+
pred = new BinaryPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
268+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
269+
else if (predictor.PredictionKind == PredictionKind.MulticlassClassification)
270+
pred = new MulticlassPredictionTransformer<IPredictorProducing<VBuffer<float>>>(env,
271+
predictor as IPredictorProducing<VBuffer<float>>, chain.Schema,
272+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value,
273+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value);
274+
else if (predictor.PredictionKind == PredictionKind.Clustering)
275+
pred = new ClusteringPredictionTransformer<IPredictorProducing<VBuffer<float>>>(env, predictor as IPredictorProducing<VBuffer<float>>, chain.Schema,
276+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
277+
else if (predictor.PredictionKind == PredictionKind.Regression)
278+
pred = new RegressionPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
279+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
280+
else if (predictor.PredictionKind == PredictionKind.AnomalyDetection)
281+
pred = new AnomalyPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
282+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
283+
else if (predictor.PredictionKind == PredictionKind.Ranking)
284+
pred = new RankingPredictionTransformer<IPredictorProducing<float>>(env, predictor as IPredictorProducing<float>, chain.Schema,
285+
roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value);
286+
else
287+
throw env.Except("Don't know how to map prediction kind {0}", predictor.PredictionKind);
288+
return transformChain.Append(pred);
302289
}
303290
}
304291
}

src/Microsoft.ML.Data/Model/ModelOperationsCatalog.cs

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.IO;
7+
using System.Linq;
78
using Microsoft.Data.DataView;
89
using Microsoft.ML.Data;
910
using Microsoft.ML.Data.IO;
@@ -32,41 +33,44 @@ internal ModelOperationsCatalog(IHostEnvironment env)
3233
Explainability = new ExplainabilityTransforms(this);
3334
}
3435

35-
private void Save<TSource>(DataViewSchema schema, IDataLoader<TSource> model, Stream stream)
36+
/// <summary>
37+
/// Save the model to the stream.
38+
/// </summary>
39+
/// <param name="model">The trained model to be saved.</param>
40+
/// <param name="stream">A writeable, seekable stream to save to.</param>
41+
public void Save<TSource>(IDataLoader<TSource> model, Stream stream)
3642
{
43+
_env.CheckValue(model, nameof(model));
44+
_env.CheckValue(stream, nameof(stream));
45+
3746
using (var rep = RepositoryWriter.CreateNew(stream))
3847
{
3948
ModelSaveContext.SaveModel(rep, model, null);
40-
SaveInputSchema(schema, rep);
4149
rep.Commit();
4250
}
4351
}
4452

45-
/// <summary>
46-
/// Save the model to the stream.
47-
/// </summary>
48-
/// <param name="model">The trained model to be saved.</param>
49-
/// <param name="stream">A writeable, seekable stream to save to.</param>
50-
public void Save<TSource>(IDataLoader<TSource> model, Stream stream)
51-
=> Save(model.GetOutputSchema(), model, stream);
52-
5353
/// <summary>
5454
/// Save a transformer model and the loader used to create its input data to the stream.
5555
/// </summary>
5656
/// <param name="loader">The loader that was used to create data to train the model</param>
5757
/// <param name="model">The trained model to be saved</param>
5858
/// <param name="stream">A writeable, seekable stream to save to.</param>
5959
public void Save<TSource>(IDataLoader<TSource> loader, ITransformer model, Stream stream) =>
60-
Save(loader.GetOutputSchema(), new CompositeDataLoader<TSource, ITransformer>(loader, new TransformerChain<ITransformer>(model)), stream);
60+
Save(new CompositeDataLoader<TSource, ITransformer>(loader, new TransformerChain<ITransformer>(model)), stream);
6161

6262
/// <summary>
6363
/// Save a transformer model and the schema of the data that was used to train it to the stream.
6464
/// </summary>
65-
/// <param name="inputSchema">The schema of the input to the transformer.</param>
6665
/// <param name="model">The trained model to be saved.</param>
66+
/// <param name="inputSchema">The schema of the input to the transformer. This can be null.</param>
6767
/// <param name="stream">A writeable, seekable stream to save to.</param>
68-
public void Save(DataViewSchema inputSchema, ITransformer model, Stream stream)
68+
public void Save(ITransformer model, DataViewSchema inputSchema, Stream stream)
6969
{
70+
_env.CheckValue(model, nameof(model));
71+
_env.CheckValueOrNull(inputSchema);
72+
_env.CheckValue(stream, nameof(stream));
73+
7074
using (var rep = RepositoryWriter.CreateNew(stream))
7175
{
7276
ModelSaveContext.SaveModel(rep, model, CompositeDataLoader<object, ITransformer>.TransformerDirectory);
@@ -77,6 +81,12 @@ public void Save(DataViewSchema inputSchema, ITransformer model, Stream stream)
7781

7882
private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
7983
{
84+
_env.AssertValueOrNull(inputSchema);
85+
_env.AssertValue(rep);
86+
87+
if (inputSchema == null)
88+
return;
89+
8090
using (var ch = _env.Start("Saving Schema"))
8191
{
8292
var entry = rep.CreateEntry(SchemaEntryName);
@@ -94,30 +104,50 @@ private void SaveInputSchema(DataViewSchema inputSchema, RepositoryWriter rep)
94104
/// <returns>The loaded model.</returns>
95105
public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
96106
{
107+
_env.CheckValue(stream, nameof(stream));
108+
97109
using (var rep = RepositoryReader.Open(stream, _env))
98110
{
99111
var entry = rep.OpenEntryOrNull(SchemaEntryName);
100112
if (entry != null)
101113
{
102114
var loader = new BinaryLoader(_env, new BinaryLoader.Arguments(), entry.Stream);
103115
inputSchema = loader.Schema;
116+
ModelLoadContext.LoadModel<ITransformer, SignatureLoadModel>(_env, out var transformerChain, rep,
117+
CompositeDataLoader<object, ITransformer>.TransformerDirectory);
118+
return transformerChain;
104119
}
105-
else
120+
121+
ModelLoadContext.LoadModelOrNull<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out var dataLoader, rep, null);
122+
if (dataLoader == null)
106123
{
124+
// Try to see if the model was saved without a loader or a schema.
125+
if (ModelLoadContext.LoadModelOrNull<ITransformer, SignatureLoadModel>(_env, out var transformerChain, rep,
126+
CompositeDataLoader<object, ITransformer>.TransformerDirectory))
127+
{
128+
inputSchema = null;
129+
return transformerChain;
130+
}
131+
107132
// Try to load from legacy model format.
108133
try
109134
{
110135
var loader = ModelFileUtils.LoadLoader(_env, rep, new MultiFileSource(null), false);
111136
inputSchema = loader.Schema;
137+
return TransformerChain.LoadFromLegacy(_env, stream);
112138
}
113139
catch (Exception ex)
114140
{
115-
if (!ex.IsMarked())
116-
throw;
117-
inputSchema = null;
141+
throw _env.Except(ex, "Could not load legacy format model");
118142
}
119143
}
120-
return TransformerChain.LoadFrom(_env, stream);
144+
if (dataLoader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
145+
{
146+
inputSchema = composite.Loader.GetOutputSchema();
147+
return composite.Transformer;
148+
}
149+
inputSchema = dataLoader.GetOutputSchema();
150+
return new TransformerChain<ITransformer>();
121151
}
122152
}
123153

@@ -127,12 +157,21 @@ public ITransformer Load(Stream stream, out DataViewSchema inputSchema)
127157
/// <param name="stream">A readable, seekable stream to load from.</param>
128158
/// <returns>A model of type <see cref="CompositeDataLoader{IMultiStreamSource, ITransformer}"/> containing the loader
129159
/// and the transformer chain.</returns>
130-
public CompositeDataLoader<IMultiStreamSource, ITransformer> Load(Stream stream)
160+
public IDataLoader<IMultiStreamSource> Load(Stream stream)
131161
{
162+
_env.CheckValue(stream, nameof(stream));
163+
132164
using (var rep = RepositoryReader.Open(stream))
133165
{
134-
ModelLoadContext.LoadModel<CompositeDataLoader<IMultiStreamSource, ITransformer>, SignatureLoadModel>(_env, out var model, rep, null);
135-
return model;
166+
try
167+
{
168+
ModelLoadContext.LoadModel<IDataLoader<IMultiStreamSource>, SignatureLoadModel>(_env, out var model, rep, null);
169+
return model;
170+
}
171+
catch (Exception ex)
172+
{
173+
throw _env.Except(ex, "Model does not contain an IDataLoader");
174+
}
136175
}
137176
}
138177

@@ -144,6 +183,8 @@ public CompositeDataLoader<IMultiStreamSource, ITransformer> Load(Stream stream)
144183
/// <returns>The transformer model from the model stream.</returns>
145184
public ITransformer Load(Stream stream, out IDataLoader<IMultiStreamSource> loader)
146185
{
186+
_env.CheckValue(stream, nameof(stream));
187+
147188
loader = Load(stream);
148189
if (loader is CompositeDataLoader<IMultiStreamSource, ITransformer> composite)
149190
{

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5648,7 +5648,6 @@ public void LoadEntryPointModel()
56485648
{
56495649
loadedModel = ml.Model.Load(stream, out DataViewSchema inputSchema);
56505650
}
5651-
56525651
}
56535652
}
56545653
}

0 commit comments

Comments
 (0)